Commit 75b53916 authored by Ilya Ovodov

CLR epochs -> iters

parent d753605f
@@ -178,7 +178,6 @@ class TensorBoardLogger:
         event = Events.ITERATION_COMPLETED if count_iters else Events.EPOCH_COMPLETED
         trainer_engine.add_event_handler(event, self.on_event)
         self.period = period
-        self.call_count = 0
         trainer_engine.add_event_handler(Events.COMPLETED, self.on_completed)
 
     def start_server(self, port, start_it = False):
@@ -200,8 +199,7 @@ class TensorBoardLogger:
         *|* are interpreted as series(train,val).plot_name(metric)
         *|*.* are interpreted as series(train,val).group(metric class).plot_name
         '''
-        self.call_count += 1
-        if self.call_count % self.period != 0:
+        if engine.state.iteration % self.period != 0:
             return
         metrics = collections.defaultdict(dict)
@@ -213,9 +211,9 @@
         for n, d in metrics.items():
             if len(d) == 1:
                 for k, v in d.items():
-                    self.writer.add_scalar(n, v, self.call_count)
+                    self.writer.add_scalar(n, v, engine.state.iteration)
             else:
-                self.writer.add_scalars(n, d, self.call_count)
+                self.writer.add_scalars(n, d, engine.state.iteration)
         for path, writer in self.writer.all_writers.items():
             writer.flush()
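
Note: the logging change above amounts to gating on ignite's built-in iteration counter instead of a handler-local one, so the step written to TensorBoard matches the engine's own progress. A minimal sketch, assuming ignite's standard Engine API; the train_step body, make_logged_engine name and print are illustrative placeholders, not part of this repo:

    from ignite.engine import Engine, Events

    def make_logged_engine(period=10):
        def train_step(engine, batch):
            return batch  # placeholder for the real training step

        engine = Engine(train_step)

        @engine.on(Events.ITERATION_COMPLETED)
        def log_metrics(engine):
            # engine.state.iteration is ignite's global 1-based iteration
            # counter, so the handler no longer keeps its own call_count
            if engine.state.iteration % period != 0:
                return
            print('iter', engine.state.iteration, dict(engine.state.metrics))

        return engine

    # make_logged_engine().run(range(100), max_epochs=1) logs every 10th iteration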
@@ -224,9 +222,9 @@ class ClrScheduler:
     def __init__(self, train_loader, model, optimizer, metric_name, params, minimize = True, engine = None):
         self.optimizer = optimizer
         self.params = params
-        self.cycle_index = 0
-        self.iter_index = 0
-        self.iterations_per_epoch = len(train_loader)
+        self.cycle_index = 0  # 0 - warmup, 1+ - cycle
+        self.inner_index = 0  # index starting from cycle begin
+        #self.iterations_per_epoch = len(train_loader)
         self.min_lr = params.clr.min_lr
         self.max_lr = params.clr.max_lr
         self.best_model_buffer = BestModelBuffer(model, metric_name, params, minimize = minimize,
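
The renamed attributes read new config keys: warmup_iters and period_iters replace the old warmup_epochs/period_epochs times len(train_loader). An illustrative params layout matching the attribute access in the diff; the AttrDict helper and the concrete values are assumptions, not this repo's actual config class:

    class AttrDict(dict):
        # minimal stand-in for whatever attribute-style config object the repo uses;
        # subclassing dict keeps .get('scale_min_lr', 1) working as in the diff
        __getattr__ = dict.__getitem__

    params = AttrDict(clr=AttrDict(
        min_lr=1e-5,        # warmup start / bottom of each cosine cycle
        max_lr=1e-3,        # peak reached after warmup_iters
        warmup_iters=500,   # previously warmup_epochs * len(train_loader)
        period_iters=2000,  # previously period_epochs * len(train_loader)
        scale_min_lr=1.0,   # optional per-cycle rescaling of the bounds
        scale_max_lr=0.9,
    ))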
@@ -245,13 +243,12 @@ class ClrScheduler:
         self.__dict__.update(state_dict)
 
     def attach(self, engine):
-        engine.add_event_handler(Events.EPOCH_STARTED, self.upd_lr_epoch)
         engine.add_event_handler(Events.ITERATION_STARTED, self.upd_lr)
         engine.add_event_handler(Events.EPOCH_COMPLETED, self.best_model_buffer)
 
-    def upd_lr_epoch(self, engine):
-        if (self.cycle_index == 0 and self.iter_index == self.params.clr.warmup_epochs*self.iterations_per_epoch
-            or self.cycle_index > 0 and self.iter_index == self.params.clr.period_epochs*self.iterations_per_epoch):
+    def upd_lr(self, engine):
+        if (self.cycle_index == 0 and self.inner_index == self.params.clr.warmup_iters
+            or self.cycle_index > 0 and self.inner_index == self.params.clr.period_iters):
             if self.cycle_index > 0:
                 self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}.t7'.format(self.cycle_index))
                 self.best_model_buffer.restore()
@@ -259,17 +256,15 @@ class ClrScheduler:
             self.min_lr *= self.params.clr.get('scale_min_lr', 1)
             self.max_lr *= self.params.clr.get('scale_max_lr', 1)
             self.cycle_index += 1
-            self.iter_index = 0
+            self.inner_index = 0
 
-    def upd_lr(self, engine):
         if self.cycle_index == 0:
-            lr = self.min_lr + (self.max_lr - self.min_lr) * self.iter_index/(self.params.clr.warmup_epochs*self.iterations_per_epoch)
+            lr = self.min_lr + (self.max_lr - self.min_lr) * self.inner_index / (self.params.clr.warmup_iters)
         else:
-            cycle_progress = self.iter_index / (self.params.clr.period_epochs*self.iterations_per_epoch)
+            cycle_progress = self.inner_index / (self.params.clr.period_iters)
             lr = self.max_lr + ((self.min_lr - self.max_lr) / 2) * (1 - math.cos(math.pi * cycle_progress))
         self.optimizer.param_groups[0]['lr'] = lr
         engine.state.metrics['lr'] = self.optimizer.param_groups[0]['lr']
-        self.iter_index += 1
+        self.inner_index += 1
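
Taken together, the merged upd_lr implements a linear warmup followed by repeated half-cosine decays, all measured in iterations rather than epochs. A pure-function restatement of the per-iteration formula (cycle bookkeeping, model save/restore and bound rescaling omitted; the clr_lr name is illustrative, the arithmetic mirrors the diff):

    import math

    def clr_lr(inner_index, cycle_index, min_lr, max_lr, warmup_iters, period_iters):
        if cycle_index == 0:
            # warmup: linear ramp from min_lr to max_lr over warmup_iters iterations
            return min_lr + (max_lr - min_lr) * inner_index / warmup_iters
        # cycle: half-cosine from max_lr down to min_lr over period_iters iterations
        cycle_progress = inner_index / period_iters
        return max_lr + (min_lr - max_lr) / 2 * (1 - math.cos(math.pi * cycle_progress))

At inner_index == 0 the cosine term vanishes and lr == max_lr; at inner_index == period_iters it reaches min_lr, which is exactly where the boundary check in upd_lr starts the next cycle, resets inner_index and optionally rescales the bounds.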