Commit 93d8841b authored by iovodov's avatar iovodov

Merge branch 'master' of https://github.com/IlyaOvodov/OvoTools into adaptiveLR

parents 3dc5aaa0 298083c6
import copy
import math
import os
import subprocess
import torch
import ignite
from ignite.engine import Events
......@@ -70,7 +71,7 @@ class BestModelBuffer:
self.model = model
assert metric_name
self.metric_name = metric_name
assert minimize == True, "Not implemented"
self.minimize = 1 if minimize else -1
self.save_to_file = save_to_file
self.verbose = verbose
self.params = params
......@@ -83,7 +84,7 @@ class BestModelBuffer:
def __call__(self, engine):
assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]:
if self.best_score is None or self.best_score*self.minimize > engine.state.metrics[self.metric_name]*self.minimize:
self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict())
self.best_epoch = engine.state.epoch
......@@ -94,7 +95,7 @@ class BestModelBuffer:
def save_model(self, rel_dir = "models", filename = None):
if filename is None:
filename = self.params.get_model_name() + ".t7"
filename = "best.t7"
file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
dir_name = os.path.dirname(file_name)
os.makedirs(dir_name, exist_ok=True)
......@@ -124,7 +125,7 @@ class LogTrainingResults:
self.best_model_buffer = best_model_buffer
self.params = params
if filename is None:
filename = self.params.get_model_name() + ".log"
filename = "log.log"
self.file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
self.calls_count = collections.defaultdict(int)
......@@ -137,7 +138,7 @@ class LogTrainingResults:
duty_cycles = self.duty_cycles
elif isinstance(self.duty_cycles, dict):
duty_cycles = self.duty_cycles[event]
if self.calls_count[event] % duty_cycles != 0:
if self.calls_count[event] % duty_cycles != 0 and self.calls_count[event] != 1: #always run 1st time to provide statistics
return
for key,loader in self.loaders_dict.items():
self.evaluator.run(loader)
......@@ -171,10 +172,15 @@ class TensorBoardLogger:
self.call_count = 0
trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed)
def start_server(self, port):
cmd = r"tensorboard --host 127.0.0.1 --port {port} --logdir {dir}".format(port=port, dir=self.log_dir)
print(cmd)
os.popen(cmd)
def start_server(self, port, start_it = False):
#cmd = r"tensorboard --host 127.0.0.1 --port {port} --logdir {dir}".format(port=port, dir=self.log_dir)
#print(cmd)
#os.popen(cmd)
cmd = r'tensorboard --host 127.0.0.1 --port {port} --logdir ""'.format(port=port).split(' ')
cmd[-1] = self.log_dir # can contain spaces
print(' '.join(cmd))
if start_it:
subprocess.Popen(cmd)
def on_completed(self, engine):
self.writer.close()
......@@ -228,7 +234,7 @@ class ClrScheduler:
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0:
self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}'.format(self.cycle_index))
self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}.t7'.format(self.cycle_index))
self.best_model_buffer.restore()
self.best_model_buffer.reset()
self.min_lr *= self.params.clr.get('scale_min_lr', 1)
......
......@@ -71,7 +71,7 @@ class AttrDict(dict):
assert self.has('data_root')
return os.path.join(self.data_root, self.get_model_name())
def save(self, base_fn = None, verbose = 1, can_overwrite = False):
def save(self, base_fn = None, verbose = 1, can_overwrite = False, create_dirs = False):
'''
save to file adding '.param.txt' to name
'''
......@@ -80,6 +80,9 @@ class AttrDict(dict):
params_fn = base_fn + '.param.txt'
if not can_overwrite:
assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn)
if create_dirs:
dir_name = os.path.dirname(params_fn)
os.makedirs(dir_name, exist_ok=True)
with open(params_fn, 'w+') as f:
s = repr(self)
s = s + '\nhash: ' + self.hash()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment