Commit b1a27f5f authored by IlyaOvodov

create_supervised_trainer, ClrScheduler, fix mb_loss

parent 9ac20fcc
import copy
import math
import torch
import ignite
from ignite.engine import Events
import collections
import time
@@ -77,7 +79,7 @@ class BestModelBuffer:
self.best_score = None
self.best_epoch = None
-    def save_if_best(self, engine):
+    def __call__(self, engine):
assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]:
self.best_score = engine.state.metrics[self.metric_name]
@@ -85,7 +87,10 @@ class BestModelBuffer:
self.best_epoch = engine.state.epoch
print('model for {}={} dumped'.format(self.metric_name, self.best_score))
if self.save_to_file:
-                torch.save(self.best_dict, self.params.get_base_filename() + '.t7')
+                self.save_model()
+    def save_model(self, suffix = ""):
+        torch.save(self.best_dict, self.params.get_base_filename() + suffix + '.t7')
def restore(self, model = None):
assert self.best_dict is not None
@@ -108,7 +113,7 @@ class LogTrainingResults:
for k,v in self.evaluator.state.metrics.items():
engine.state.metrics[key+':'+k] = v
if self.best_model_buffer:
-            self.best_model_buffer.save_if_best(engine)
+            self.best_model_buffer(engine)
if event == Events.ITERATION_COMPLETED:
str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
else:
@@ -160,3 +165,88 @@ class TensorBoardLogger:
for path, writer in self.writer.all_writers.items():
writer.flush()
class ClrScheduler:
def __init__(self, train_loader, model, optimizer, metric_name, params, minimize = True, engine = None):
self.optimizer = optimizer
self.params = params
self.cycle_index = 0
self.iter_index = 0
self.iterations_per_epoch = len(train_loader)
self.min_lr = params.clr.min_lr
self.max_lr = params.clr.max_lr
self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False)
if engine:
self.attach(engine)
def attach(self, engine):
engine.add_event_handler(Events.EPOCH_STARTED, self.upd_lr_epoch)
engine.add_event_handler(Events.ITERATION_STARTED, self.upd_lr)
engine.add_event_handler(Events.EPOCH_COMPLETED, self.best_model_buffer)
def upd_lr_epoch(self, engine):
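        # Called on EPOCH_STARTED: when warmup (cycle 0) or a full cycle has just
        # finished, dump the cycle's best model (for cycles > 0), restore it as the
        # starting point, scale the LR range, and begin the next cycle.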
if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
if self.cycle_index > 0:
self.best_model_buffer.save_model('.'+str(self.cycle_index))
self.best_model_buffer.restore()
self.best_model_buffer.reset()
self.min_lr *= self.params.clr.scale_min_lr
self.max_lr *= self.params.clr.scale_max_lr
self.cycle_index += 1
self.iter_index = 0
def upd_lr(self, engine):
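        # Called on ITERATION_STARTED: linear ramp from min_lr to max_lr during
        # warmup (cycle 0), then a cosine half-wave from max_lr down to min_lr over
        # each later cycle. Only param_groups[0] is updated; lr is reported as a metric.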
if self.cycle_index == 0:
lr = self.min_lr + (self.max_lr - self.min_lr) * self.iter_index/(self.params.clr.warmup_epochs*self.iterations_per_epoch)
else:
cycle_progress = self.iter_index / (self.params.clr.period_epochs*self.iterations_per_epoch)
lr = self.max_lr + ((self.min_lr - self.max_lr) / 2) * (1 - math.cos(math.pi * cycle_progress))
self.optimizer.param_groups[0]['lr'] = lr
engine.state.metrics['lr'] = self.optimizer.param_groups[0]['lr']
self.iter_index += 1
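As a quick standalone check of the cosine schedule implemented in upd_lr (the clr_lr helper and the values here are illustrative, not part of the commit): at the start of a cycle the cosine term vanishes and the learning rate equals max_lr, halfway it is the midpoint of the range, and at the end of the cycle it reaches min_lr.

import math

def clr_lr(progress, min_lr=1e-5, max_lr=1e-2):
    # cosine half-wave from max_lr down to min_lr, as in ClrScheduler.upd_lr
    return max_lr + (min_lr - max_lr) / 2 * (1 - math.cos(math.pi * progress))

assert abs(clr_lr(0.0) - 1e-2) < 1e-12               # cycle start: max_lr
assert abs(clr_lr(0.5) - (1e-2 + 1e-5) / 2) < 1e-12  # midpoint of the range
assert abs(clr_lr(1.0) - 1e-5) < 1e-12               # cycle end: min_lr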
def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
device=None, non_blocking=False,
prepare_batch=ignite.engine._prepare_batch):
"""
Factory function for creating a trainer for supervised models.
Args:
model (`torch.nn.Module`): the model to train.
optimizer (`torch.optim.Optimizer`): the optimizer to use.
loss_fn (torch.nn loss function): the loss function to use.
device (str, optional): device type specification (default: None).
Applies to both model and batches.
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
Note: `engine.state.output` for this engine is the loss of the processed batch.
Returns:
Engine: a trainer engine with supervised update function.
"""
if device:
model.to(device)
def _update(engine, batch):
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
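        # return (y_pred, y) rather than the loss, so the metrics attached below
        # can be computed from engine.state.output on the trainer itself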
return y_pred, y
engine = ignite.engine.Engine(_update)
for name, metric in metrics.items():
metric.attach(engine, name)
return engine
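For context, a minimal end-to-end sketch of how create_supervised_trainer and ClrScheduler fit together. The model, data, and Params layout are illustrative assumptions: the clr fields are inferred from the reads in ClrScheduler above, and get_base_filename() from the checkpoint call in BestModelBuffer.

import torch
from ignite.metrics import Loss

# illustrative model and data; any supervised setup works here
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 10),
                                   torch.randint(0, 2, (64,))),
    batch_size=8)

class Params:                        # illustrative; layout inferred from usage
    def get_base_filename(self):     # used when dumping per-cycle checkpoints
        return 'model'
    class clr:
        min_lr, max_lr = 1e-5, 1e-2
        warmup_epochs = 1
        period_epochs = 10
        scale_min_lr = scale_max_lr = 0.5

trainer = create_supervised_trainer(
    model, optimizer, loss_fn,
    metrics={'loss': Loss(loss_fn)})  # works because output is (y_pred, y)
ClrScheduler(train_loader, model, optimizer, 'loss', Params(), engine=trainer)
trainer.run(train_loader, max_epochs=30)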
@@ -62,12 +62,12 @@ class MarginBaseLoss:
self.false_neg = 0
alpha = self.model.mb_loss_alpha if self.params.mb_loss.train_alpha else self.model.mb_loss_alpha.detach()
+        alpha2 = self.model.mb_loss_alpha if self.params.mb_loss.train_alpha else 0
-        with self.timer.watch('time.d_ij'):
-            assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
-            norm = (pred_embeddings ** 2).sum(1)
-            self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) # https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/8
-            self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
+        assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
+        norm = (pred_embeddings ** 2).sum(1)
+        self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) # https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/8
+        self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
for i_start in range(0, n, self.params.data.samples_per_class): # start of class block
            i_end = i_start + self.params.data.samples_per_class # end of class block
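The d_ij block above computes all pairwise distances at once via the expansion ||a − b||² = ||a||² + ||b||² − 2⟨a, b⟩; the clamp and the 1.0e-8 term guard the sqrt against small negative values from rounding. A standalone check of that identity against torch.cdist (illustrative only):

import torch

x = torch.randn(6, 4)                    # 6 embeddings of dimension 4
norm = (x ** 2).sum(1)
d2 = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(x, x.t())
d = torch.sqrt(torch.clamp(d2, min=0.0) + 1.0e-8)
assert torch.allclose(d, torch.cdist(x, x), atol=1e-3)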
@@ -80,24 +80,32 @@ class MarginBaseLoss:
weights_same = weights[i_start: i_end] # i-th element already excluded
j = np.random.choice(range(i_start, i_end), p = weights_same/np.sum(weights_same), replace=False)
assert j != i
-                loss += (alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) - alpha # https://arxiv.org/pdf/1706.07567.pdf
+                loss += (alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) - alpha2 # https://arxiv.org/pdf/1706.07567.pdf
# select neg. pair
weights = np.delete(weights, np.s_[i_start: i_end], axis=0)
-                k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights), replace=False)
+                with self.timer.watch('time.mb_loss_k'):
+                    k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights), replace=False)
if k >= i_start:
k += self.params.data.samples_per_class
-                loss += ((alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) - alpha)*self.params.mb_loss.neg2pos_weight # https://arxiv.org/pdf/1706.07567.pdf
+                loss += ((alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) - alpha2)*self.params.mb_loss.neg2pos_weight # https://arxiv.org/pdf/1706.07567.pdf
self.mb_loss_val = loss[0] / len(pred_embeddings)
-        negative = (d > self.model.mb_loss_beta.detach()).float()
-        positive = (d <= self.model.mb_loss_beta.detach()).float()
-        fn = sum(negative[i_start: i_end])
-        self.false_neg += fn
-        tp = sum(positive[i_start: i_end])
-        self.true_pos += tp
-        fp = sum(positive[: i_start]) + sum(positive[i_end:])
-        self.false_pos += fp
-        fn = sum(negative[: i_start]) + sum(negative[i_end:])
-        self.true_neg += fn
+        with self.timer.watch('time.mb_loss_acc1'):
+            # compare against beta once and derive positive as the complement
+            negative = (d > self.model.mb_loss_beta.detach())
+            positive = (~negative).float()
+            negative = negative.float()
+        with self.timer.watch('time.mb_loss_acc2'):
+            # tensor .sum() instead of Python's sum() avoids a per-element loop
+            fn = (negative[i_start: i_end]).sum()
+            self.false_neg += fn
+            tp = (positive[i_start: i_end]).sum()
+            self.true_pos += tp
+            fp = (positive[: i_start]).sum() + (positive[i_end:]).sum()
+            self.false_pos += fp
+            tn = (negative[: i_start]).sum() + (negative[i_end:]).sum()  # true negatives (was misleadingly named fn)
+            self.true_neg += tn
self.true_pos /= n
self.true_neg /= n
self.false_pos /= n
......
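For reference, the pair terms above implement the margin-based loss of Wu et al., arXiv:1706.07567: a pair (i, j) with y_ij = 1 for same class and y_ij = −1 otherwise contributes max(0, α + y_ij·(d_ij − β)), so positives are pulled inside distance β − α and negatives pushed beyond β + α; the alpha2 change makes the trailing −α offset apply only when α itself is being trained. A small numeric sketch with illustrative values:

import torch

alpha, beta = 0.2, 1.2                    # illustrative margin parameters
d = torch.tensor([1.5, 0.9, 1.0, 1.6])    # pair distances
y = torch.tensor([1., 1., -1., -1.])      # 1 = same class, -1 = different
loss = (alpha + y * (d - beta)).clamp(min=0)
# positive at 1.5 and negative at 1.0 are penalized; the other two are inside their margins
assert torch.allclose(loss, torch.tensor([0.5, 0.0, 0.4, 0.0]))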