Commit 50d1d812 authored by IlyaOvodov's avatar IlyaOvodov

Save models to dir. Verbose argument.

parent 8cae81d8
import copy import copy
import math import math
import os
import torch import torch
import ignite import ignite
from ignite.engine import Events from ignite.engine import Events
...@@ -65,12 +66,14 @@ class IgniteTimes: ...@@ -65,12 +66,14 @@ class IgniteTimes:
class BestModelBuffer: class BestModelBuffer:
def __init__(self, model, metric_name, params, minimize = True, save_to_file = True): def __init__(self, model, metric_name, params, minimize = True, save_to_file = True, save_to_dir_suffix = None, verbose = 1):
self.model = model self.model = model
assert metric_name assert metric_name
self.metric_name = metric_name self.metric_name = metric_name
assert minimize == True, "Not implemented" assert minimize == True, "Not implemented"
self.save_to_file = save_to_file self.save_to_file = save_to_file
self.save_to_dir_suffix = save_to_dir_suffix
self.verbose = verbose
self.params = params self.params = params
self.reset() self.reset()
...@@ -85,17 +88,25 @@ class BestModelBuffer: ...@@ -85,17 +88,25 @@ class BestModelBuffer:
self.best_score = engine.state.metrics[self.metric_name] self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict()) self.best_dict = copy.deepcopy(self.model.state_dict())
self.best_epoch = engine.state.epoch self.best_epoch = engine.state.epoch
if self.verbose:
print('model for {}={} dumped'.format(self.metric_name, self.best_score)) print('model for {}={} dumped'.format(self.metric_name, self.best_score))
if self.save_to_file: if self.save_to_file:
self.save_model() self.save_model()
def save_model(self, suffix = ""): def save_model(self, file_suffix = "model"):
torch.save(self.best_dict, self.params.get_base_filename() + suffix + '.t7') if self.save_to_dir_suffix is not None:
dir_name = self.params.get_base_filename() + self.save_to_dir_suffix
os.makedirs(dir_name, exist_ok=True)
file_name = os.path.join(dir_name, file_suffix + '.t7')
else:
file_name = self.params.get_base_filename() + file_suffix + '.t7'
torch.save(self.best_dict, file_name)
def restore(self, model = None): def restore(self, model = None):
assert self.best_dict is not None assert self.best_dict is not None
if model is None: if model is None:
model = self.model model = self.model
if self.verbose:
print('model for {}={} on epoch {} restored'.format(self.metric_name, self.best_score, self.best_epoch)) print('model for {}={} on epoch {} restored'.format(self.metric_name, self.best_score, self.best_epoch))
model.load_state_dict(self.best_dict) model.load_state_dict(self.best_dict)
...@@ -175,7 +186,8 @@ class ClrScheduler: ...@@ -175,7 +186,8 @@ class ClrScheduler:
self.iterations_per_epoch = len(train_loader) self.iterations_per_epoch = len(train_loader)
self.min_lr = params.clr.min_lr self.min_lr = params.clr.min_lr
self.max_lr = params.clr.max_lr self.max_lr = params.clr.max_lr
self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False) self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False,
save_to_dir_suffix = '.clr_models', verbose = 0)
if engine: if engine:
self.attach(engine) self.attach(engine)
...@@ -188,7 +200,7 @@ class ClrScheduler: ...@@ -188,7 +200,7 @@ class ClrScheduler:
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch): or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0: if self.cycle_index > 0:
self.best_model_buffer.save_model('.'+str(self.cycle_index)) self.best_model_buffer.save_model('{:03}'.format(self.cycle_index))
self.best_model_buffer.restore() self.best_model_buffer.restore()
self.best_model_buffer.reset() self.best_model_buffer.reset()
self.min_lr *= self.params.clr.scale_min_lr self.min_lr *= self.params.clr.scale_min_lr
...@@ -247,6 +259,6 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={}, ...@@ -247,6 +259,6 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
engine = ignite.engine.Engine(_update) engine = ignite.engine.Engine(_update)
for name, metric in metrics.items(): for name, metric in metrics.items():
metric.attach(engine, name) metric.attach(engine, 'train:' + name)
return engine return engine
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment