Commit b1a27f5f authored by IlyaOvodov

create_supervised_trainer, ClrScheduler, fix mb_loss

parent 9ac20fcc
 import copy
+import math
 import torch
+import ignite
 from ignite.engine import Events
 import collections
 import time
@@ -77,7 +79,7 @@ class BestModelBuffer:
         self.best_score = None
         self.best_epoch = None
-    def save_if_best(self, engine):
+    def __call__(self, engine):
         assert self.metric_name in engine.state.metrics.keys(), "{} {}".format(self.metric_name, engine.state.metrics.keys())
         if self.best_score is None or self.best_score > engine.state.metrics[self.metric_name]:
             self.best_score = engine.state.metrics[self.metric_name]
@@ -85,7 +87,10 @@ class BestModelBuffer:
             self.best_epoch = engine.state.epoch
             print('model for {}={} dumped'.format(self.metric_name, self.best_score))
             if self.save_to_file:
-                torch.save(self.best_dict, self.params.get_base_filename() + '.t7')
+                self.save_model()
+
+    def save_model(self, suffix = ""):
+        torch.save(self.best_dict, self.params.get_base_filename() + suffix + '.t7')
     def restore(self, model = None):
         assert self.best_dict is not None
@@ -108,7 +113,7 @@ class LogTrainingResults:
             for k,v in self.evaluator.state.metrics.items():
                 engine.state.metrics[key+':'+k] = v
             if self.best_model_buffer:
-                self.best_model_buffer.save_if_best(engine)
+                self.best_model_buffer(engine)
             if event == Events.ITERATION_COMPLETED:
                 str = "Epoch:{}.{}\t".format(engine.state.epoch, engine.state.iteration)
             else:
@@ -160,3 +165,88 @@ class TensorBoardLogger:
         for path, writer in self.writer.all_writers.items():
             writer.flush()

class ClrScheduler:
    # Cyclic learning rate scheduler: a linear warmup from min_lr to max_lr,
    # then cosine cycles from max_lr down to min_lr. At the end of each cycle
    # the best model seen so far is saved and restored before the next cycle,
    # and the lr bounds are rescaled.
    def __init__(self, train_loader, model, optimizer, metric_name, params, minimize = True, engine = None):
        self.optimizer = optimizer
        self.params = params
        self.cycle_index = 0
        self.iter_index = 0
        self.iterations_per_epoch = len(train_loader)
        self.min_lr = params.clr.min_lr
        self.max_lr = params.clr.max_lr
        self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize, save_to_file = False)
        if engine:
            self.attach(engine)

    def attach(self, engine):
        engine.add_event_handler(Events.EPOCH_STARTED, self.upd_lr_epoch)
        engine.add_event_handler(Events.ITERATION_STARTED, self.upd_lr)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.best_model_buffer)

    def upd_lr_epoch(self, engine):
        # At the end of the warmup phase or of a full cycle: dump the best model,
        # restore its weights, reset the buffer and rescale the lr bounds.
        if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
                or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
            if self.cycle_index > 0:
                self.best_model_buffer.save_model('.'+str(self.cycle_index))
            self.best_model_buffer.restore()
            self.best_model_buffer.reset()
            self.min_lr *= self.params.clr.scale_min_lr
            self.max_lr *= self.params.clr.scale_max_lr
            self.cycle_index += 1
            self.iter_index = 0

    def upd_lr(self, engine):
        if self.cycle_index == 0:
            # warmup: linear ramp from min_lr to max_lr
            lr = self.min_lr + (self.max_lr - self.min_lr) * self.iter_index/(self.params.clr.warmup_epochs*self.iterations_per_epoch)
        else:
            # cosine annealing from max_lr to min_lr over one cycle
            cycle_progress = self.iter_index / (self.params.clr.period_epochs*self.iterations_per_epoch)
            lr = self.max_lr + ((self.min_lr - self.max_lr) / 2) * (1 - math.cos(math.pi * cycle_progress))
        self.optimizer.param_groups[0]['lr'] = lr
        engine.state.metrics['lr'] = self.optimizer.param_groups[0]['lr']
        self.iter_index += 1
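For orientation, here is a minimal sketch of how ClrScheduler might be wired up. It assumes a model, optimizer, train_loader and an ignite trainer engine already exist; the SimpleNamespace-based params container and the 'train:loss' metric name are illustrative stand-ins and not part of this commit (a real params object must also provide whatever BestModelBuffer expects, e.g. get_base_filename()).

# Hypothetical wiring example -- the clr fields mirror exactly what ClrScheduler reads.
from types import SimpleNamespace

params = SimpleNamespace(clr=SimpleNamespace(
    min_lr=1e-4, max_lr=1e-2,             # lr bounds for warmup and cycles
    warmup_epochs=2, period_epochs=8,     # phase lengths in epochs
    scale_min_lr=0.5, scale_max_lr=0.5))  # lr-bound decay between cycles

scheduler = ClrScheduler(train_loader, model, optimizer,
                         metric_name='train:loss',  # must exist in engine.state.metrics
                         params=params, engine=trainer)
trainer.run(train_loader, max_epochs=50)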
def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
                              device=None, non_blocking=False,
                              prepare_batch=ignite.engine._prepare_batch):
    """
    Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        metrics (dict of str - `ignite.metrics.Metric`, optional): metrics to attach to the trainer.
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.

    Note: unlike the stock ignite trainer, `engine.state.output` for this engine is the tuple `(y_pred, y)`
    of the processed batch, so the attached metrics can consume it directly.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if device:
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return y_pred, y

    engine = ignite.engine.Engine(_update)
    for name, metric in metrics.items():
        metric.attach(engine, name)
    return engine
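A hedged usage sketch for this trainer factory: because `_update` returns `(y_pred, y)`, standard ignite metrics can be attached to the training engine directly. The model, optimizer, loss_fn and train_loader below are assumed to already exist, and the metric names are arbitrary.

# Illustrative only -- model, optimizer, loss_fn and train_loader are assumed to exist.
from ignite.metrics import Loss, Accuracy

trainer = create_supervised_trainer(model, optimizer, loss_fn,
                                     metrics={'loss': Loss(loss_fn), 'accuracy': Accuracy()},
                                     device='cuda' if torch.cuda.is_available() else None)
trainer.run(train_loader, max_epochs=10)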
@@ -62,12 +62,12 @@ class MarginBaseLoss:
         self.false_neg = 0
         alpha = self.model.mb_loss_alpha if self.params.mb_loss.train_alpha else self.model.mb_loss_alpha.detach()
+        alpha2 = self.model.mb_loss_alpha if self.params.mb_loss.train_alpha else 0
-        with self.timer.watch('time.d_ij'):
-            assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
-            norm = (pred_embeddings ** 2).sum(1)
-            self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) #https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/8
-            self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
+        assert len(pred_embeddings.shape) == 2, pred_embeddings.shape
+        norm = (pred_embeddings ** 2).sum(1)
+        self.d_ij = norm.view(-1, 1) + norm.view(1, -1) - 2.0 * torch.mm(pred_embeddings, torch.transpose(pred_embeddings, 0, 1)) #https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/8
+        self.d_ij = torch.sqrt(torch.clamp(self.d_ij, min=0.0) + 1.0e-8)
         for i_start in range(0, n, self.params.data.samples_per_class): # start of class block
             i_end = i_start + self.params.data.samples_per_class # start of class block
@@ -80,24 +80,32 @@ class MarginBaseLoss:
                 weights_same = weights[i_start: i_end] # i-th element already excluded
                 j = np.random.choice(range(i_start, i_end), p = weights_same/np.sum(weights_same), replace=False)
                 assert j != i
-                loss += (alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) - alpha #https://arxiv.org/pdf/1706.07567.pdf
+                loss += (alpha + (self.d_ij[i,j] - self.model.mb_loss_beta)).clamp(min=0) - alpha2 #https://arxiv.org/pdf/1706.07567.pdf
                 # select neg. pair
                 weights = np.delete(weights, np.s_[i_start: i_end], axis=0)
-                k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights), replace=False)
+                with self.timer.watch('time.mb_loss_k'):
+                    k = np.random.choice(range(0, n - self.params.data.samples_per_class), p = weights/np.sum(weights), replace=False)
                 if k >= i_start:
                     k += self.params.data.samples_per_class
-                loss += ((alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) - alpha)*self.params.mb_loss.neg2pos_weight #https://arxiv.org/pdf/1706.07567.pdf
+                loss += ((alpha - (self.d_ij[i,k] - self.model.mb_loss_beta)).clamp(min=0) - alpha2)*self.params.mb_loss.neg2pos_weight #https://arxiv.org/pdf/1706.07567.pdf
                 self.mb_loss_val = loss[0] / len(pred_embeddings)
-                negative = (d > self.model.mb_loss_beta.detach()).float()
-                positive = (d <= self.model.mb_loss_beta.detach()).float()
-                fn = sum(negative[i_start: i_end])
-                self.false_neg += fn
-                tp = sum(positive[i_start: i_end])
-                self.true_pos += tp
-                fp = sum(positive[: i_start]) + sum(positive[i_end:])
-                self.false_pos += fp
-                fn = sum(negative[: i_start]) + sum(negative[i_end:])
-                self.true_neg += fn
+                with self.timer.watch('time.mb_loss_acc1'):
+                    '''
+                    negative = (d > self.model.mb_loss_beta.detach()).float()
+                    positive = (d <= self.model.mb_loss_beta.detach()).float()
+                    '''
+                    negative = (d > self.model.mb_loss_beta.detach())
+                    positive = (~negative).float()
+                    negative = negative.float()
+                with self.timer.watch('time.mb_loss_acc2'):
+                    fn = (negative[i_start: i_end]).sum()
+                    self.false_neg += fn
+                    tp = (positive[i_start: i_end]).sum()
+                    self.true_pos += tp
+                    fp = (positive[: i_start]).sum() + (positive[i_end:]).sum()
+                    self.false_pos += fp
+                    fn = (negative[: i_start]).sum() + (negative[i_end:]).sum()
+                    self.true_neg += fn
         self.true_pos /= n
         self.true_neg /= n
         self.false_pos /= n
...
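For context, the hinge terms above follow the margin-based loss of the referenced paper (arXiv:1706.07567, "Sampling Matters in Deep Embedding Learning"). As I read the code, for embeddings e_i the pairwise distances and per-pair terms are roughly:

d_{ij} = \sqrt{\lVert e_i \rVert^2 + \lVert e_j \rVert^2 - 2\, e_i^{\top} e_j},
\qquad
\ell_{\mathrm{pos}}(i,j) = \big(\alpha + d_{ij} - \beta\big)_{+},
\qquad
\ell_{\mathrm{neg}}(i,k) = \big(\alpha - (d_{ik} - \beta)\big)_{+}

The negative-pair term is scaled by `neg2pos_weight`. This commit subtracts `alpha2` (equal to `alpha` when the margin is trainable, zero otherwise) instead of `alpha` from both terms, which removes the constant offset from the loss when `alpha` is not trained.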