Commit 298083c6 authored by Ilya Ovodov's avatar Ilya Ovodov

BestModelBuffer max implemented, some fixes

parent 477e1590
...@@ -71,7 +71,7 @@ class BestModelBuffer: ...@@ -71,7 +71,7 @@ class BestModelBuffer:
self.model = model self.model = model
assert metric_name assert metric_name
self.metric_name = metric_name self.metric_name = metric_name
assert minimize == True, "Not implemented" self.minimize = 1 if minimize else -1
self.save_to_file = save_to_file self.save_to_file = save_to_file
self.verbose = verbose self.verbose = verbose
self.params = params self.params = params
...@@ -84,7 +84,7 @@ class BestModelBuffer: ...@@ -84,7 +84,7 @@ class BestModelBuffer:
def __call__(self, engine): def __call__(self, engine):
assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys()) assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]: if self.best_score is None or self.best_score*self.minimize > engine.state.metrics[self.metric_name]*self.minimize:
self.best_score = engine.state.metrics[self.metric_name] self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict()) self.best_dict = copy.deepcopy(self.model.state_dict())
self.best_epoch = engine.state.epoch self.best_epoch = engine.state.epoch
...@@ -95,7 +95,7 @@ class BestModelBuffer: ...@@ -95,7 +95,7 @@ class BestModelBuffer:
def save_model(self, rel_dir = "models", filename = None): def save_model(self, rel_dir = "models", filename = None):
if filename is None: if filename is None:
filename = self.params.get_model_name() + ".t7" filename = "best.t7"
file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename) file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
dir_name = os.path.dirname(file_name) dir_name = os.path.dirname(file_name)
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
...@@ -234,7 +234,7 @@ class ClrScheduler: ...@@ -234,7 +234,7 @@ class ClrScheduler:
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch): or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0: if self.cycle_index > 0:
self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}'.format(self.cycle_index)) self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}.t7'.format(self.cycle_index))
self.best_model_buffer.restore() self.best_model_buffer.restore()
self.best_model_buffer.reset() self.best_model_buffer.reset()
self.min_lr *= self.params.clr.get('scale_min_lr', 1) self.min_lr *= self.params.clr.get('scale_min_lr', 1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment