Commit 6419986d authored by IlyaOvodov's avatar IlyaOvodov

Timer, Trensorboard logger

parent f21f1e8e
import copy import copy
import torch import torch
from ignite.engine import Events from ignite.engine import Events
import collections
import time
import tensorboardX
class IgniteTimes:
class TimerWatch:
def __init__(self, timer, name):
self.name = name
self.timer = timer
def __enter__(self):
self.timer.start(self.name)
return self
def __exit__(self, *args):
self.timer.end(self.name)
return False
def __init__(self, engine, count_iters=False, measured_events={}):
self.clocks = dict()
self.sums = collections.defaultdict(float)
self.counts = collections.defaultdict(int)
for name, (event_engine, start_event, end_event) in measured_events.items():
event_engine.add_event_handler(start_event, self.on_start, name)
event_engine.add_event_handler(end_event, self.on_end, name)
event = Events.ITERATION_COMPLETED if count_iters else Events.EPOCH_COMPLETED
engine.add_event_handler(event, self.on_complete)
def reset_all(self):
self.clocks.clear()
self.sums.clear()
self.counts.clear()
def start(self, name):
assert not name in self.clocks
self.clocks[name] = time.time()
def end(self, name):
assert name in self.clocks
t = time.time() - self.clocks[name]
self.counts[name] += 1
self.sums[name] += t
self.clocks.pop(name)
def watch(self, name):
return self.TimerWatch(self, name)
def on_start(self, engine, name):
self.start(name)
def on_end(self, engine, name):
self.end(name)
def on_complete(self, engine):
for n, v in self.sums.items():
engine.state.metrics[n] = v
self.reset_all()
class BestModelBuffer: class BestModelBuffer:
...@@ -49,7 +108,7 @@ class LogTrainingResults: ...@@ -49,7 +108,7 @@ class LogTrainingResults:
for key,loader in self.loaders_dict.items(): for key,loader in self.loaders_dict.items():
self.evaluator.run(loader) self.evaluator.run(loader)
for k,v in self.evaluator.state.metrics.items(): for k,v in self.evaluator.state.metrics.items():
engine.state.metrics[key+'.'+k] = v engine.state.metrics[key+':'+k] = v
self.best_model_buffer.save_if_best(engine) self.best_model_buffer.save_if_best(engine)
if event == Events.ITERATION_COMPLETED: if event == Events.ITERATION_COMPLETED:
str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration) str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
...@@ -59,3 +118,43 @@ class LogTrainingResults: ...@@ -59,3 +118,43 @@ class LogTrainingResults:
print(str) print(str)
with open(self.params.get_base_filename() + '.log', 'a') as f: with open(self.params.get_base_filename() + '.log', 'a') as f:
f.write(str + '\n') f.write(str + '\n')
class TensorBoardLogger:
SERIES_PLOT_SEPARATOR = ':'
GROUP_PLOT_SEPARATOR = '.'
def __init__(self, trainer_engine, params, count_iters=False, period=1):
log_dir = params.get_base_filename()
self.writer = tensorboardX.SummaryWriter(log_dir=log_dir, flush_secs = 10)
event = Events.ITERATION_COMPLETED if count_iters else Events.EPOCH_COMPLETED
trainer_engine.add_event_handler(event, self.on_event)
self.period = period
self.call_count = 0
trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed)
def on_completed(self, engine):
self.writer.close()
def on_event(self, engine):
'''
engine.state.metrics with name
*|* are interpreted as series(train,val).plot_name(metric)
*|*.* are interpreted as series(train,val).group(metric class).plot_name
'''
self.call_count += 1
if self.call_count % self.period != 0:
return
metrics = collections.defaultdict(dict)
for name, value in engine.state.metrics.items():
name_parts = name.split(self.SERIES_PLOT_SEPARATOR, 1)
if len(name_parts) == 1:
name_parts.append(name_parts[0])
metrics[name_parts[1].replace(self.GROUP_PLOT_SEPARATOR, '/')][name_parts[0]] = value
for n, d in metrics.items():
if len(d) == 1:
for k, v in d.items():
self.writer.add_scalar(n, v, self.call_count)
else:
self.writer.add_scalars(n, d, self.call_count)
import torch import torch
import numpy as np import numpy as np
class DummyTimer:
'''
replacement for IgniteTimer if it is not provided
'''
class TimerWatch:
def __init__(self, timer, name): pass
def __enter__(self): return self
def __exit__(self, *args): return False
def __init__(self): pass
def start(self, name): pass
def end(self, name): pass
def watch(self, name): return self.TimerWatch(self, name)
class MarginBaseLoss: class MarginBaseLoss:
''' '''
L2-constrained Softmax Loss for Discriminative Face Verification https://arxiv.org/pdf/1703.09507 L2-constrained Softmax Loss for Discriminative Face Verification https://arxiv.org/pdf/1703.09507
margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf
''' '''
ignore_index = -100 ignore_index = -100
def __init__(self, model, classes, device, params): def __init__(self, model, classes, device, params, timer = DummyTimer()):
assert params.data.samples_per_class >= 2 assert params.data.samples_per_class >= 2
self.model = model self.model = model
self.device = device self.device = device
...@@ -15,15 +31,20 @@ class MarginBaseLoss: ...@@ -15,15 +31,20 @@ class MarginBaseLoss:
self.classes = sorted(classes) self.classes = sorted(classes)
self.classes_dict = {v: i for i, v in enumerate(self.classes)} self.classes_dict = {v: i for i, v in enumerate(self.classes)}
self.lambda_rev = 1/params.distance_weighted_sampling.lambda_ self.lambda_rev = 1/params.distance_weighted_sampling.lambda_
self.timer = timer
print('classes: ', len(self.classes)) print('classes: ', len(self.classes))
def set_timer(self, timer):
self.timer = timer
def classes_to_ids(self, y_class, ignore_index = -100): def classes_to_ids(self, y_class, ignore_index = -100):
return torch.tensor([self.classes_dict.get(int(c.item()), ignore_index) for c in y_class]).to(self.device) return torch.tensor([self.classes_dict.get(int(c.item()), ignore_index) for c in y_class]).to(self.device)
def l2_loss(self, net_output, y_class): def l2_loss(self, net_output, y_class):
pred_class = net_output[0] with self.timer.watch('time.l2_loss'):
class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index) pred_class = net_output[0]
return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos) class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index)
return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)
def D(self, pred_embeddings, i ,j): def D(self, pred_embeddings, i ,j):
if i == j: if i == j:
...@@ -32,31 +53,33 @@ class MarginBaseLoss: ...@@ -32,31 +53,33 @@ class MarginBaseLoss:
def mb_loss(self, net_output, y_class): def mb_loss(self, net_output, y_class):
pred_embeddings = net_output[1] with self.timer.watch('time.mb_loss'):
loss = 0 pred_embeddings = net_output[1]
n = len(pred_embeddings) # samples in batch loss = 0
dim = pred_embeddings[0].shape[0] # dimensionality n = len(pred_embeddings) # samples in batch
for i_start in range(0, n, self.params.data.samples_per_class): # start of class block dim = pred_embeddings[0].shape[0] # dimensionality
i_end = i_start + self.params.data.samples_per_class # start of class block for i_start in range(0, n, self.params.data.samples_per_class): # start of class block
for i in range(i_start, i_end -1): i_end = i_start + self.params.data.samples_per_class # start of class block
d_ij = [0 if i==j else self.D(pred_embeddings, i, j) for j in range(n)] for i in range(i_start, i_end -1):
weights = [1/max(self.lambda_rev, pow(d,dim-2)*pow(1-d*d/4, (dim-3)/2)) # https://arxiv.org/pdf/1706.07567.pdf with self.timer.watch('time.d_ij'):
for id, d in enumerate(d_ij) if id != i] # dont join with itself d_ij = [0 if i==j else self.D(pred_embeddings, i, j) for j in range(n)]
weights_same = np.asarray(weights[i_start: i_end-1]) # i-th element already excluded weights = [1/max(self.lambda_rev, pow(d,dim-2)*pow(1-d*d/4, (dim-3)/2)) # https://arxiv.org/pdf/1706.07567.pdf
j = np.random.choice(range(i_start, i_end-1), p = weights_same/np.sum(weights_same) ) for id, d in enumerate(d_ij) if id != i] # dont join with itself
if j >= i: weights_same = np.asarray(weights[i_start: i_end-1]) # i-th element already excluded
j += 1 j = np.random.choice(range(i_start, i_end-1), p = weights_same/np.sum(weights_same) )
# for j in range(i+1, i_end): # positive pair if j >= i:
loss += (self.params.mb_loss.alpha + (d_ij[j] - self.model.mb_loss_beta)).clamp(min=0) j += 1
# select neg. pait # for j in range(i+1, i_end): # positive pair
weights[i_start: i_end - 1] = [] # i-th element already excluded loss += (self.params.mb_loss.alpha + (d_ij[j] - self.model.mb_loss_beta)).clamp(min=0)
weights = np.asarray(weights) # select neg. pait
weights = weights/np.sum(weights) weights[i_start: i_end - 1] = [] # i-th element already excluded
k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights) weights = np.asarray(weights)
if k >= i_start: weights = weights/np.sum(weights)
k += self.params.data.samples_per_class k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights)
loss += (self.params.mb_loss.alpha - (d_ij[k] - self.model.mb_loss_beta)).clamp(min=0) if k >= i_start:
return loss[0] / len(pred_embeddings) k += self.params.data.samples_per_class
loss += (self.params.mb_loss.alpha - (d_ij[k] - self.model.mb_loss_beta)).clamp(min=0)
return loss[0] / len(pred_embeddings)
def loss(self, net_output, y_class): def loss(self, net_output, y_class):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment