Commit 50d1d812 authored by IlyaOvodov's avatar IlyaOvodov

Save models to dir. Verbose argument.

parent 8cae81d8
import copy
import math
import os
import torch
import ignite
from ignite.engine import Events
......@@ -65,12 +66,14 @@ class IgniteTimes:
class BestModelBuffer:
def __init__(self, model, metric_name, params, minimize = True, save_to_file = True):
def __init__(self, model, metric_name, params, minimize = True, save_to_file = True, save_to_dir_suffix = None, verbose = 1):
self.model = model
assert metric_name
self.metric_name = metric_name
assert minimize == True, "Not implemented"
self.save_to_file = save_to_file
self.save_to_dir_suffix = save_to_dir_suffix
self.verbose = verbose
self.params = params
self.reset()
......@@ -85,18 +88,26 @@ class BestModelBuffer:
self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict())
self.best_epoch = engine.state.epoch
print('model for {}={} dumped'.format(self.metric_name, self.best_score))
if self.verbose:
print('model for {}={} dumped'.format(self.metric_name, self.best_score))
if self.save_to_file:
self.save_model()
def save_model(self, suffix = ""):
torch.save(self.best_dict, self.params.get_base_filename() + suffix + '.t7')
def save_model(self, file_suffix = "model"):
if self.save_to_dir_suffix is not None:
dir_name = self.params.get_base_filename() + self.save_to_dir_suffix
os.makedirs(dir_name, exist_ok=True)
file_name = os.path.join(dir_name, file_suffix + '.t7')
else:
file_name = self.params.get_base_filename() + file_suffix + '.t7'
torch.save(self.best_dict, file_name)
def restore(self, model = None):
assert self.best_dict is not None
if model is None:
model = self.model
print('model for {}={} on epoch {} restored'.format(self.metric_name, self.best_score, self.best_epoch))
if self.verbose:
print('model for {}={} on epoch {} restored'.format(self.metric_name, self.best_score, self.best_epoch))
model.load_state_dict(self.best_dict)
......@@ -175,7 +186,8 @@ class ClrScheduler:
self.iterations_per_epoch = len(train_loader)
self.min_lr = params.clr.min_lr
self.max_lr = params.clr.max_lr
self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False)
self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False,
save_to_dir_suffix = '.clr_models', verbose = 0)
if engine:
self.attach(engine)
......@@ -188,7 +200,7 @@ class ClrScheduler:
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0:
self.best_model_buffer.save_model('.'+str(self.cycle_index))
self.best_model_buffer.save_model('{:03}'.format(self.cycle_index))
self.best_model_buffer.restore()
self.best_model_buffer.reset()
self.min_lr *= self.params.clr.scale_min_lr
......@@ -247,6 +259,6 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
engine = ignite.engine.Engine(_update)
for name, metric in metrics.items():
metric.attach(engine, name)
metric.attach(engine, 'train:' + name)
return engine
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment