Commit 6419986d authored by IlyaOvodov

Timer, Tensorboard logger

parent f21f1e8e
import copy
import torch
from ignite.engine import Events
import collections
import time
import tensorboardX
class IgniteTimes:
    '''
    Accumulates the elapsed time of named code sections and publishes the
    per-name totals into ``engine.state.metrics``.

    A section can be measured in three ways:
      * explicitly, with ``start(name)`` / ``end(name)``;
      * as a context manager, with ``with timer.watch(name): ...``;
      * between two engine events, through ``measured_events``.

    Every time the reporting event fires (see ``__init__``) the accumulated
    sums are copied to ``engine.state.metrics`` and all state is reset.
    '''

    class TimerWatch:
        '''
        Context manager that times its body under ``name`` on ``timer``.
        '''

        def __init__(self, timer, name):
            self.name = name
            self.timer = timer

        def __enter__(self):
            self.timer.start(self.name)
            return self

        def __exit__(self, *args):
            self.timer.end(self.name)
            # False: an exception raised inside the block is never suppressed.
            return False

    def __init__(self, engine, count_iters=False, measured_events=None):
        '''
        :param engine: engine whose ``state.metrics`` receives the totals.
        :param count_iters: report on ``Events.ITERATION_COMPLETED`` when true,
            otherwise on ``Events.EPOCH_COMPLETED``.
        :param measured_events: optional mapping
            ``name -> (event_engine, start_event, end_event)``; the time between
            the two events of ``event_engine`` is accumulated under ``name``.
        '''
        self.clocks = dict()  # name -> start timestamp of a running section
        self.sums = collections.defaultdict(float)  # name -> total seconds
        self.counts = collections.defaultdict(int)  # name -> finished sections
        # ``None`` replaces the former shared mutable ``{}`` default.
        if measured_events is None:
            measured_events = {}
        # Measurement handlers are registered before the reporting handler,
        # exactly as in the original registration order.
        for name, (event_engine, start_event, end_event) in measured_events.items():
            event_engine.add_event_handler(start_event, self.on_start, name)
            event_engine.add_event_handler(end_event, self.on_end, name)
        event = Events.ITERATION_COMPLETED if count_iters else Events.EPOCH_COMPLETED
        engine.add_event_handler(event, self.on_complete)

    def reset_all(self):
        '''
        Drops all accumulated sums and counts and forgets running sections.
        '''
        self.clocks.clear()
        self.sums.clear()
        self.counts.clear()

    def start(self, name):
        '''
        Starts timing section ``name``; the section must not be running.
        '''
        assert name not in self.clocks
        # perf_counter is monotonic, so system clock adjustments cannot
        # produce negative or inflated intervals.
        self.clocks[name] = time.perf_counter()

    def end(self, name):
        '''
        Stops timing section ``name`` and adds the elapsed time to its total.
        '''
        assert name in self.clocks
        elapsed = time.perf_counter() - self.clocks.pop(name)
        self.counts[name] += 1
        self.sums[name] += elapsed

    def watch(self, name):
        '''
        Returns a context manager that times its body under ``name``.
        '''
        return self.TimerWatch(self, name)

    def on_start(self, engine, name):
        '''
        Event handler: starts section ``name`` (``engine`` is unused).
        '''
        self.start(name)

    def on_end(self, engine, name):
        '''
        Event handler: ends section ``name`` (``engine`` is unused).
        '''
        self.end(name)

    def on_complete(self, engine):
        '''
        Event handler: publishes every accumulated sum into
        ``engine.state.metrics`` under its section name, then resets all state.
        '''
        for name, total in self.sums.items():
            engine.state.metrics[name] = total
        # NOTE: reset_all() also discards sections that are still running.
        self.reset_all()
class BestModelBuffer:
......@@ -49,7 +108,7 @@ class LogTrainingResults:
for key,loader in self.loaders_dict.items():
self.evaluator.run(loader)
for k,v in self.evaluator.state.metrics.items():
engine.state.metrics[key+'.'+k] = v
engine.state.metrics[key+':'+k] = v
self.best_model_buffer.save_if_best(engine)
if event == Events.ITERATION_COMPLETED:
str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
......@@ -59,3 +118,43 @@ class LogTrainingResults:
print(str)
with open(self.params.get_base_filename() + '.log', 'a') as f:
f.write(str + '\n')
class TensorBoardLogger:
    '''
    Writes the contents of ``engine.state.metrics`` to TensorBoard.

    Metric names are split on the first SERIES_PLOT_SEPARATOR (':') into
    ``series:plot``. Metrics that share the same ``plot`` part are drawn as
    separate series on one chart. Every GROUP_PLOT_SEPARATOR ('.') inside the
    ``plot`` part is replaced by '/' to form the TensorBoard tag.
    '''
    SERIES_PLOT_SEPARATOR = ':'
    GROUP_PLOT_SEPARATOR = '.'

    def __init__(self, trainer_engine, params, count_iters=False, period=1):
        '''
        :param trainer_engine: engine whose metrics are logged.
        :param params: object whose ``get_base_filename()`` gives the log dir.
        :param count_iters: log on ``Events.ITERATION_COMPLETED`` when true,
            otherwise on ``Events.EPOCH_COMPLETED``.
        :param period: only every ``period``-th firing of the event is written.
        '''
        log_dir = params.get_base_filename()
        self.writer = tensorboardX.SummaryWriter(log_dir=log_dir, flush_secs=10)
        event = Events.ITERATION_COMPLETED if count_iters else Events.EPOCH_COMPLETED
        trainer_engine.add_event_handler(event, self.on_event)
        self.period = period
        self.call_count = 0
        trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed)

    def on_completed(self, engine):
        '''
        Event handler: closes the summary writer.
        '''
        self.writer.close()

    def on_event(self, engine):
        '''
        engine.state.metrics with name
        *:* are interpreted as series(train,val):plot_name(metric)
        *:*.* are interpreted as series(train,val):group(metric class).plot_name
        A name without ':' is plotted alone under its own name.
        The number of calls of this handler is used as the step value.
        '''
        self.call_count += 1
        if self.call_count % self.period != 0:
            return
        # tag -> {series name -> value}
        plots = collections.defaultdict(dict)
        for name, value in engine.state.metrics.items():
            series, separator, plot = name.partition(self.SERIES_PLOT_SEPARATOR)
            if not separator:
                # No series prefix: the name serves as both series and plot.
                plot = series
            tag = plot.replace(self.GROUP_PLOT_SEPARATOR, '/')
            plots[tag][series] = value
        for tag, series_values in plots.items():
            if len(series_values) == 1:
                (value,) = series_values.values()
                self.writer.add_scalar(tag, value, self.call_count)
            else:
                self.writer.add_scalars(tag, series_values, self.call_count)
import torch
import numpy as np
class DummyTimer:
    '''
    No-op stand-in offering the start/end/watch timer interface,
    used when no real timer is provided.
    '''

    class TimerWatch:
        '''
        Context manager that does nothing on entry or exit.
        '''

        def __init__(self, timer, name):
            '''Accepts and discards the owning timer and the section name.'''

        def __enter__(self):
            return self

        def __exit__(self, *args):
            # False: exceptions raised inside the block propagate.
            return False

    def __init__(self):
        '''Nothing to initialise.'''

    def start(self, name):
        '''Ignores the request to start timing ``name``.'''

    def end(self, name):
        '''Ignores the request to stop timing ``name``.'''

    def watch(self, name):
        '''Returns a no-op context manager for ``name``.'''
        watcher_cls = self.TimerWatch
        return watcher_cls(self, name)
class MarginBaseLoss:
'''
L2-constrained Softmax Loss for Discriminative Face Verification https://arxiv.org/pdf/1703.09507
margin based loss with distance weighted sampling https://arxiv.org/pdf/1706.07567.pdf
'''
ignore_index = -100
def __init__(self, model, classes, device, params):
def __init__(self, model, classes, device, params, timer = DummyTimer()):
assert params.data.samples_per_class >= 2
self.model = model
self.device = device
......@@ -15,15 +31,20 @@ class MarginBaseLoss:
self.classes = sorted(classes)
self.classes_dict = {v: i for i, v in enumerate(self.classes)}
self.lambda_rev = 1/params.distance_weighted_sampling.lambda_
self.timer = timer
print('classes: ', len(self.classes))
def set_timer(self, timer):
    '''
    Replaces the timer stored in ``self.timer``.
    The loss methods call ``self.timer.watch(name)`` on it.
    '''
    self.timer = timer
def classes_to_ids(self, y_class, ignore_index = -100):
return torch.tensor([self.classes_dict.get(int(c.item()), ignore_index) for c in y_class]).to(self.device)
def l2_loss(self, net_output, y_class):
pred_class = net_output[0]
class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index)
return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)
with self.timer.watch('time.l2_loss'):
pred_class = net_output[0]
class_nos = self.classes_to_ids(y_class, ignore_index=self.ignore_index)
return torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)(pred_class, class_nos)
def D(self, pred_embeddings, i ,j):
if i == j:
......@@ -32,31 +53,33 @@ class MarginBaseLoss:
def mb_loss(self, net_output, y_class):
    '''
    Margin based loss with distance weighted sampling
    (https://arxiv.org/pdf/1706.07567.pdf), timed under 'time.mb_loss'.

    The batch is processed as consecutive blocks of
    ``params.data.samples_per_class`` samples. For every anchor ``i`` of a
    block except the block's last sample, one sample ``j`` from the same
    block and one sample ``k`` from outside the block are drawn at random,
    and two hinge terms are added to the loss.

    :param net_output: indexable network output; element 1 holds the
        embeddings, one per sample.
    :param y_class: not used by this method.
    :return: accumulated loss divided by the number of samples.
    '''
    with self.timer.watch('time.mb_loss'):
        pred_embeddings = net_output[1]
        loss = 0
        n = len(pred_embeddings)  # samples in batch
        dim = pred_embeddings[0].shape[0]  # dimensionality
        samples_per_class = self.params.data.samples_per_class
        for i_start in range(0, n, samples_per_class):  # start of class block
            i_end = i_start + samples_per_class  # end of class block (exclusive)
            for i in range(i_start, i_end - 1):
                # NOTE(review): indentation was lost in the source; the timed
                # region is assumed to cover only the distance computation.
                with self.timer.watch('time.d_ij'):
                    d_ij = [0 if i == j else self.D(pred_embeddings, i, j) for j in range(n)]
                # Sampling weights, see https://arxiv.org/pdf/1706.07567.pdf;
                # the i-th element is left out (don't pair a sample with itself).
                weights = [1/max(self.lambda_rev, pow(d, dim-2)*pow(1-d*d/4, (dim-3)/2))
                           for idx, d in enumerate(d_ij) if idx != i]

                # select positive pair
                weights_same = np.asarray(weights[i_start: i_end-1])  # i-th element already excluded
                j = np.random.choice(range(i_start, i_end-1), p=weights_same/np.sum(weights_same))
                if j >= i:
                    j += 1  # shift past the excluded anchor index
                loss += (self.params.mb_loss.alpha + (d_ij[j] - self.model.mb_loss_beta)).clamp(min=0)

                # select negative pair
                weights[i_start: i_end - 1] = []  # i-th element already excluded
                weights = np.asarray(weights)
                weights = weights/np.sum(weights)
                k = np.random.choice(range(0, n - samples_per_class), p=weights)
                if k >= i_start:
                    k += samples_per_class  # shift past the anchor's own block
                loss += (self.params.mb_loss.alpha - (d_ij[k] - self.model.mb_loss_beta)).clamp(min=0)
        # NOTE(review): loss[0] presumes each added term is indexable (e.g. a
        # 1-element tensor returned by self.D) - confirm against D.
        return loss[0] / len(pred_embeddings)
def loss(self, net_output, y_class):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment