Commit f21f1e8e authored by IlyaOvodov

data_root at AttrDict, ignite_tools, pytorch_tools

parent d780d2bc
import copy
import torch
from ignite.engine import Events


class BestModelBuffer:
    '''
    Keeps a copy of the model weights with the best (lowest) value of a given metric,
    optionally dumping them to <params.get_base_filename()>.t7.
    '''
    def __init__(self, model, metric_name, params, minimize = True, save_to_file = True):
        self.model = model
        assert metric_name
        self.metric_name = metric_name
        assert minimize == True, "Not implemented"
        self.save_to_file = save_to_file
        self.params = params
        self.reset()

    def reset(self):
        self.best_dict = None
        self.best_score = None
        self.best_epoch = None

    def save_if_best(self, engine):
        assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
        if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]:
            self.best_score = engine.state.metrics[self.metric_name]
            self.best_dict = copy.deepcopy(self.model.state_dict())
            self.best_epoch = engine.state.epoch
            print('model for {}={} dumped'.format(self.metric_name, self.best_score))
            if self.save_to_file:
                torch.save(self.best_dict, self.params.get_base_filename() + '.t7')

    def restore(self, model = None):
        assert self.best_dict is not None
        if model is None:
            model = self.model
        print('model for {}={} on epoch {} restored'.format(self.metric_name, self.best_score, self.best_epoch))
        model.load_state_dict(self.best_dict)


class LogTrainingResults:
    '''
    Event handler: runs the evaluator over every loader in loaders_dict, copies the resulting
    metrics into engine.state.metrics under '<loader key>.<metric>' names, lets
    best_model_buffer store the best model, then prints the metrics and appends them to the log file.
    '''
    def __init__(self, evaluator, loaders_dict, best_model_buffer, params):
        self.evaluator = evaluator
        self.loaders_dict = loaders_dict
        self.best_model_buffer = best_model_buffer
        self.params = params

    def __call__(self, engine, event):
        if event == Events.ITERATION_COMPLETED and engine.state.epoch != 1:
            return
        for key, loader in self.loaders_dict.items():
            self.evaluator.run(loader)
            for k, v in self.evaluator.state.metrics.items():
                engine.state.metrics[key + '.' + k] = v
        self.best_model_buffer.save_if_best(engine)
        if event == Events.ITERATION_COMPLETED:
            str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
        else:
            str = "Epoch:{}\t".format(engine.state.epoch)
        str += '\t'.join(['{}:{:.3f}'.format(k, v) for k, v in engine.state.metrics.items()])
        print(str)
        with open(self.params.get_base_filename() + '.log', 'a') as f:
            f.write(str + '\n')
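
A minimal wiring sketch (not part of this commit) of how these two handlers can be attached to an Ignite trainer; model, optimizer, criterion, the data loaders and params are assumed to exist, and the metric name 'val.loss' is only an example:

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(model, metrics={'loss': Loss(criterion)}, device=device)

best_model_buffer = BestModelBuffer(model, 'val.loss', params)
log_results = LogTrainingResults(evaluator,
                                 loaders_dict={'train': train_loader, 'val': val_loader},
                                 best_model_buffer=best_model_buffer,
                                 params=params)
# The handler receives the engine as its first argument; the event type is passed as an extra
# positional argument so __call__(engine, event) can tell epoch- from iteration-level calls.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results, Events.EPOCH_COMPLETED)

trainer.run(train_loader, max_epochs=30)
best_model_buffer.restore()  # reload the best weights after training
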
@@ -13,9 +13,10 @@ class AttrDict(dict):
    Example:
    from params import AttrDict
    params = AttrDict(
+       data_root = local_config.data_path,
-       model_name = 'model/inner_halo_types_m{inner_halo_params.margin}_w{inner_halo_params.loss_weights}',
        fold_test_substrs = ['/cam_7_7/', '/cam_7_8/', '/cam_7_9/'],
        fold_no = 0,
+       model_name = 'model/inner_halo_types_m{inner_halo_params.margin}_w{inner_halo_params.loss_weights}',
        model_params = AttrDict(output_channels=3, enc_type='se_resnet50',
                                dec_type='unet_scse',
                                num_filters=16, pretrained=True),
@@ -23,8 +24,10 @@ class AttrDict(dict):
        std = (0.2965651186330059, 0.2801510185680299, 0.2719146471588908),
    )
    ...
    base_filename = my_models.ModelFileName(params)
-   params.save(base_filename)
+   params.save()
+   The parameters 'data_root' and 'model_name' are required by the save() and get_base_filename() functions.
+   The parameter 'data_root' is not stored and does not influence the hash.
    '''
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
@@ -33,16 +36,16 @@ class AttrDict(dict):
            assert '.' not in k, "AttrDict: attribute '" + k + "' is invalid ('.' char is not allowed)"
            if isinstance(v, dict):
                self[k] = AttrDict(v)

    def __repr__(self):
        return ('{\n' +
                '\n'.join([repr(x[0]) + ' : ' + repr(x[1]) + ','
-                          for x in vars(self).items() if not x[0].startswith('__')]) +
+                          for x in vars(self).items() if not x[0].startswith('__') and x[0] != 'data_root']) +
                ' \n}')

    def has(self, name):
        '''
-       checks if self containes attribute with some name, including recursive, i.e. 'b.b1' etc.
+       checks if self contains attribute with some name, including recursive, i.e. 'b.b1' etc.
        '''
        names = name.split('.')
        dic = self
@@ -56,14 +59,27 @@ class AttrDict(dict):
        '''
        hash of dict values, invariant to values order
        '''
-       return hashlib.sha1(json.dumps(self, sort_keys=True).encode()).hexdigest()[:shrink_to]
-   def save(self, base_fn, verbose = False):
+       hash_dict = self.copy()
+       hash_dict.pop('data_root', None)
+       return hashlib.sha1(json.dumps(hash_dict, sort_keys=True).encode()).hexdigest()[:shrink_to]
+
+   def get_model_name(self):
+       assert self.has('model_name')
+       return self.model_name.format(**self) + '_' + self.hash()
+
+   def get_base_filename(self):
+       assert self.has('data_root')
+       return os.path.join(self.data_root, self.get_model_name())
+
+   def save(self, base_fn = None, verbose = False, can_overwrite = False):
        '''
        save to file adding '.param.txt' to name
        '''
+       if base_fn is None:
+           base_fn = self.get_base_filename()
        params_fn = base_fn + '.param.txt'
-       assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn)
+       if not can_overwrite:
+           assert not os.path.exists(params_fn), "Can't save parameters to {}: File exists".format(params_fn)
        with open(params_fn, 'w+') as f:
            s = repr(self)
            s = s + '\nhash: ' + self.hash()
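
The practical effect of the new hash()/get_base_filename() pair, sketched with hypothetical values: the hash, and hence the model file name, no longer depends on where the data lives.

p1 = AttrDict(data_root='/data/run_a', model_name='baseline', lr=1e-3)
p2 = AttrDict(data_root='/mnt/run_b', model_name='baseline', lr=1e-3)
assert p1.hash() == p2.hash()        # 'data_root' is popped before hashing
print(p1.get_base_filename())        # /data/run_a/baseline_<hash>
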
@@ -71,24 +87,26 @@ class AttrDict(dict):
        if verbose:
            print('params: '+ s + '\nsaved to ' + params_fn)

-   def LoadFromStr(s):
+   def load_from_str(s, data_root):
        s = ''.join(s)
        s = s.replace('\n', '')
        assert len(s) >= 2
        assert s[0][0] == '{'
        assert s[-1][-1] == '}'
        params = ast.literal_eval(s)
+       if data_root:
+           params['data_root'] = data_root
        return AttrDict(params)

-   def Load(params_fn, verbose = False):
+   def load(params_fn, data_root, verbose = False):
        '''
        loads from file, adding '.param.txt' to name
        '''
        import ast
-       with open(params_fn + '.param.txt') as f:
+       with open(params_fn) as f:
            s = f.readlines()
        assert s[-1].startswith('hash:')
-       params = AttrDict.LoadFromStr(s[:-1])
+       params = AttrDict.load_from_str(s[:-1], data_root)
        if verbose:
            print('params: '+ repr(params) + '\nhash: ' + params.hash() + '\nloaded from ' + params_fn)
        return params
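
A save/load round trip under the new interface (illustration only; the path and parameter values are hypothetical). load() now takes the full file name plus the data_root to graft back in, and the restored dict hashes to the same value:

params = AttrDict(data_root='/data/exp1', model_name='baseline', lr=1e-3)
params.save(can_overwrite=True)      # writes <data_root>/baseline_<hash>.param.txt
fn = params.get_base_filename() + '.param.txt'
restored = AttrDict.load(fn, data_root='/mnt/other_machine')
assert restored.hash() == params.hash()   # data_root differs, hash does not
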
......
import torch
import numpy as np


class MarginBaseLoss:
    '''
    L2-constrained Softmax Loss for Discriminative Face Verification https://arxiv.org/pdf/1703.09507
    margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf
    '''
    ignore_index = -100

    def __init__(self, model, classes, device, params):
        assert params.data.samples_per_class >= 2
        self.model = model
        self.device = device
        self.params = params
        self.classes = sorted(classes)
        self.classes_dict = {v: i for i, v in enumerate(self.classes)}
        self.lambda_rev = 1 / params.distance_weighted_sampling.lambda_
        print('classes: ', len(self.classes))

    def classes_to_ids(self, y_class, ignore_index = -100):
        return torch.tensor([self.classes_dict.get(int(c.item()), ignore_index) for c in y_class]).to(self.device)

    def l2_loss(self, net_output, y_class):
        # cross-entropy on the class predictions; unknown classes are mapped to ignore_index
        pred_class = net_output[0]
        class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index)
        return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)

    def D(self, pred_embeddings, i, j):
        # Euclidean distance between two embeddings of the batch
        if i == j:
            return 0
        return torch.dist(pred_embeddings[i], pred_embeddings[j]).item()

    def mb_loss(self, net_output, y_class):
        pred_embeddings = net_output[1]
        loss = 0
        n = len(pred_embeddings)            # samples in batch
        dim = pred_embeddings[0].shape[0]   # dimensionality
        for i_start in range(0, n, self.params.data.samples_per_class):  # start of class block
            i_end = i_start + self.params.data.samples_per_class         # end of class block
            for i in range(i_start, i_end - 1):
                d_ij = [0 if i == j else self.D(pred_embeddings, i, j) for j in range(n)]
                # distance weighted sampling, https://arxiv.org/pdf/1706.07567.pdf
                weights = [1 / max(self.lambda_rev, pow(d, dim - 2) * pow(1 - d * d / 4, (dim - 3) / 2))
                           for id, d in enumerate(d_ij) if id != i]  # don't pair with itself
                # positive pair: sample j from the same class block
                weights_same = np.asarray(weights[i_start: i_end - 1])  # i-th element already excluded
                j = np.random.choice(range(i_start, i_end - 1), p = weights_same / np.sum(weights_same))
                if j >= i:
                    j += 1
                # for j in range(i+1, i_end): # positive pair
                loss += (self.params.mb_loss.alpha + (d_ij[j] - self.model.mb_loss_beta)).clamp(min=0)
                # negative pair: sample k from the other class blocks
                weights[i_start: i_end - 1] = []  # i-th element already excluded
                weights = np.asarray(weights)
                weights = weights / np.sum(weights)
                k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights)
                if k >= i_start:
                    k += self.params.data.samples_per_class
                loss += (self.params.mb_loss.alpha - (d_ij[k] - self.model.mb_loss_beta)).clamp(min=0)
        return loss[0] / len(pred_embeddings)

    def loss(self, net_output, y_class):
        return self.l2_loss(net_output, y_class) + self.mb_loss(net_output, y_class)
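
A rough sketch of how this loss could be driven in a training step (not part of the commit; the model is assumed to return (class_logits, embeddings), mb_loss_beta is assumed to be a learnable scalar on the model, and params must carry data.samples_per_class, distance_weighted_sampling.lambda_ and mb_loss.alpha):

criterion = MarginBaseLoss(model, classes=train_dataset.classes, device=device, params=params)

def train_step(x, y_class):
    model.train()
    optimizer.zero_grad()
    net_output = model(x.to(device))            # (class predictions, embeddings)
    loss = criterion.loss(net_output, y_class)  # cross-entropy + margin-based term
    loss.backward()
    optimizer.step()
    return loss.item()
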