Commit 93d8841b authored by iovodov

Merge branch 'master' of https://github.com/IlyaOvodov/OvoTools into adaptiveLR

parents 3dc5aaa0 298083c6
import copy import copy
import math import math
import os import os
import subprocess
import torch import torch
import ignite import ignite
from ignite.engine import Events from ignite.engine import Events
...@@ -70,7 +71,7 @@ class BestModelBuffer: ...@@ -70,7 +71,7 @@ class BestModelBuffer:
self.model = model self.model = model
assert metric_name assert metric_name
self.metric_name = metric_name self.metric_name = metric_name
assert minimize == True, "Not implemented" self.minimize = 1 if minimize else -1
self.save_to_file = save_to_file self.save_to_file = save_to_file
self.verbose = verbose self.verbose = verbose
self.params = params self.params = params
...@@ -83,7 +84,7 @@ class BestModelBuffer: ...@@ -83,7 +84,7 @@ class BestModelBuffer:
def __call__(self, engine): def __call__(self, engine):
assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys()) assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]: if self.best_score is None or self.best_score*self.minimize > engine.state.metrics[self.metric_name]*self.minimize:
self.best_score = engine.state.metrics[self.metric_name] self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict()) self.best_dict = copy.deepcopy(self.model.state_dict())
self.best_epoch = engine.state.epoch self.best_epoch = engine.state.epoch
...@@ -94,7 +95,7 @@ class BestModelBuffer: ...@@ -94,7 +95,7 @@ class BestModelBuffer:
def save_model(self, rel_dir = "models", filename = None): def save_model(self, rel_dir = "models", filename = None):
if filename is None: if filename is None:
filename = self.params.get_model_name() + ".t7" filename = "best.t7"
file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename) file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
dir_name = os.path.dirname(file_name) dir_name = os.path.dirname(file_name)
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
...@@ -124,7 +125,7 @@ class LogTrainingResults: ...@@ -124,7 +125,7 @@ class LogTrainingResults:
self.best_model_buffer = best_model_buffer self.best_model_buffer = best_model_buffer
self.params = params self.params = params
if filename is None: if filename is None:
filename = self.params.get_model_name() + ".log" filename = "log.log"
self.file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename) self.file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
self.calls_count = collections.defaultdict(int) self.calls_count = collections.defaultdict(int)
...@@ -137,7 +138,7 @@ class LogTrainingResults: ...@@ -137,7 +138,7 @@ class LogTrainingResults:
duty_cycles = self.duty_cycles duty_cycles = self.duty_cycles
elif isinstance(self.duty_cycles, dict): elif isinstance(self.duty_cycles, dict):
duty_cycles = self.duty_cycles[event] duty_cycles = self.duty_cycles[event]
if self.calls_count[event] % duty_cycles != 0: if self.calls_count[event] % duty_cycles != 0 and self.calls_count[event] != 1: #always run 1st time to provide statistics
return return
for key,loader in self.loaders_dict.items(): for key,loader in self.loaders_dict.items():
self.evaluator.run(loader) self.evaluator.run(loader)
...@@ -171,10 +172,15 @@ class TensorBoardLogger: ...@@ -171,10 +172,15 @@ class TensorBoardLogger:
self.call_count = 0 self.call_count = 0
trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed) trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed)
def start_server(self, port): def start_server(self, port, start_it = False):
cmd = r"tensorboard --host 127.0.0.1 --port {port} --logdir {dir}".format(port=port, dir=self.log_dir) #cmd = r"tensorboard --host 127.0.0.1 --port {port} --logdir {dir}".format(port=port, dir=self.log_dir)
print(cmd) #print(cmd)
os.popen(cmd) #os.popen(cmd)
cmd = r'tensorboard --host 127.0.0.1 --port {port} --logdir ""'.format(port=port).split(' ')
cmd[-1] = self.log_dir # can contain spaces
print(' '.join(cmd))
if start_it:
subprocess.Popen(cmd)
def on_completed(self, engine): def on_completed(self, engine):
self.writer.close() self.writer.close()
...@@ -228,7 +234,7 @@ class ClrScheduler: ...@@ -228,7 +234,7 @@ class ClrScheduler:
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch): or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0: if self.cycle_index > 0:
self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}'.format(self.cycle_index)) self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}.t7'.format(self.cycle_index))
self.best_model_buffer.restore() self.best_model_buffer.restore()
self.best_model_buffer.reset() self.best_model_buffer.reset()
self.min_lr *= self.params.clr.get('scale_min_lr', 1) self.min_lr *= self.params.clr.get('scale_min_lr', 1)
......
...@@ -71,7 +71,7 @@ class AttrDict(dict): ...@@ -71,7 +71,7 @@ class AttrDict(dict):
assert self.has('data_root') assert self.has('data_root')
return os.path.join(self.data_root, self.get_model_name()) return os.path.join(self.data_root, self.get_model_name())
def save(self, base_fn = None, verbose = 1, can_overwrite = False): def save(self, base_fn = None, verbose = 1, can_overwrite = False, create_dirs = False):
''' '''
save to file adding '.param.txt' to name save to file adding '.param.txt' to name
''' '''
...@@ -80,6 +80,9 @@ class AttrDict(dict): ...@@ -80,6 +80,9 @@ class AttrDict(dict):
params_fn = base_fn + '.param.txt' params_fn = base_fn + '.param.txt'
if not can_overwrite: if not can_overwrite:
assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn) assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn)
if create_dirs:
dir_name = os.path.dirname(params_fn)
os.makedirs(dir_name, exist_ok=True)
with open(params_fn, 'w+') as f: with open(params_fn, 'w+') as f:
s = repr(self) s = repr(self)
s = s + '\nhash: ' + self.hash() s = s + '\nhash: ' + self.hash()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment