Commit d753605f authored by Ilya Ovodov's avatar Ilya Ovodov

fixes for ignite checkpoint

parent 0109c4c6
......@@ -8,6 +8,7 @@ from ignite.engine import Events
import collections
import time
import tensorboardX
from typing import Dict, Any
......@@ -77,13 +78,21 @@ class BestModelBuffer:
self.params = params
self.reset()
def state_dict(self) -> Dict[str, Any]:
return {key: value for key, value in self.__dict__.items() if key not in {"model", "params"}}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
def reset(self):
self.best_dict = None
self.best_score = None
self.best_epoch = None
def __call__(self, engine):
assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
if self.metric_name not in engine.state.metrics.keys():
print("Warning: metric {} not in {}".format(self.metric_name, engine.state.metrics.keys()))
return
if self.best_score is None or self.best_score*self.minimize > engine.state.metrics[self.metric_name]*self.minimize:
self.best_score = engine.state.metrics[self.metric_name]
self.best_dict = copy.deepcopy(self.model.state_dict())
......@@ -225,6 +234,16 @@ class ClrScheduler:
if engine:
self.attach(engine)
def state_dict(self) -> Dict[str, Any]:
own_dict = {key: value for key, value in self.__dict__.items() if key not in {"optimizer", "params", "best_model_buffer"}}
own_dict["best_model_buffer"] = self.best_model_buffer.state_dict()
return own_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.best_model_buffer.load_state_dict(state_dict["best_model_buffer"])
del state_dict["best_model_buffer"]
self.__dict__.update(state_dict)
def attach(self, engine):
engine.add_event_handler(Events.EPOCH_STARTED, self.upd_lr_epoch)
engine.add_event_handler(Events.ITERATION_STARTED, self.upd_lr)
......
......@@ -2,6 +2,7 @@ import hashlib
import json
import ast
import os
from typing import Dict, Any
from collections import OrderedDict
......@@ -141,6 +142,12 @@ class AttrDict(OrderedDict):
print('loaded from ' + str(params_fn))
return params
def state_dict(self) -> Dict[str, Any]:
return {key: value for key, value in self.__dict__.items()}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
if __name__=='__main__':
m = AttrDict(
......
......@@ -4,6 +4,11 @@ import torch
SEED = 241075
# TODO
#PYHTONHASHSEED
#https://github.com/n01z3/kaggle-pneumothorax-segmentation/blob/master/n15_train.py#L32-L49
#https://pytorch.org/docs/stable/generated/torch.set_deterministic.html#torch.set_deterministic
def set_reproducibility(seed = SEED):
'''
attempts to make calculations reproducible
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment