Commit 68d5cadd authored by IlyaOvodov's avatar IlyaOvodov

Log duty cycle

parent 4703916b
......@@ -110,7 +110,15 @@ class BestModelBuffer:
class LogTrainingResults:
def __init__(self, evaluator, loaders_dict, best_model_buffer, params, rel_dir = "", filename = None):
def __init__(self, evaluator, loaders_dict, best_model_buffer, params, rel_dir = "", filename = None, duty_cycles = 1):
'''
evaluates metrics using evaluator and data loaders, adds them to caller engine metrics, logs and keeps track on best model
:param evaluator: ignite engine to evaluate metrics
:param loaders_dict: dict: {'name': data_loader}. Metrics are evaluated for all data_loaders, resulting metric named 'name:original_metric_name'
:param best_model_buffer: optrional, tracks model with best mertric
:param params: global params to get params.get_base_filename()
:param duty_cycles: int or dict: {event: int}: if set, enables functioning only every n-th call fo each event
'''
self.evaluator = evaluator
self.loaders_dict = loaders_dict
self.best_model_buffer = best_model_buffer
......@@ -119,7 +127,18 @@ class LogTrainingResults:
filename = self.params.get_model_name() + ".log"
self.file_name = os.path.join(self.params.get_base_filename(), rel_dir, filename)
self.calls_count = collections.defaultdict(int)
assert isinstance(duty_cycles, (int, dict,))
self.duty_cycles = duty_cycles
def __call__(self, engine, event):
self.calls_count[event] += 1
if isinstance(self.duty_cycles, int):
duty_cycles = self.duty_cycles
elif isinstance(self.duty_cycles, dict):
duty_cycles = self.duty_cycles[event]
if self.calls_count[event] % duty_cycles != 0:
return
for key,loader in self.loaders_dict.items():
self.evaluator.run(loader)
for k,v in self.evaluator.state.metrics.items():
......@@ -212,8 +231,8 @@ class ClrScheduler:
self.best_model_buffer.save_model(rel_dir = 'models', filename = 'clr.{:03}'.format(self.cycle_index))
self.best_model_buffer.restore()
self.best_model_buffer.reset()
self.min_lr *= self.params.clr.scale_min_lr
self.max_lr *= self.params.clr.scale_max_lr
self.min_lr *= self.params.clr.get('scale_min_lr', 1)
self.max_lr *= self.params.clr.get('scale_max_lr', 1)
self.cycle_index += 1
self.iter_index = 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment