Commit b97a30ed authored by IlyaOvodov

adaptiveLR v3

parent a2fb35df
@@ -4,10 +4,10 @@ import ignite
 def create_adaptive_supervised_trainer(model, optimizer, loss_fn, metrics={},
                                         device=None, non_blocking=False,
-                                        prepare_batch=ignite.engine._prepare_batch, lr_scale = 1.1, warmup_iters = 50):
+                                        prepare_batch=ignite.engine._prepare_batch, lr_scale = 1.1, warmup_iters = 50, ls_mult = 3):
     """
     Factory function for creating a trainer for supervised models.

     Args:
         model (`torch.nn.Module`): the model to train.
         optimizer (`torch.optim.Optimizer`): the optimizer to use.
@@ -41,31 +41,40 @@ def create_adaptive_supervised_trainer(model, optimizer, loss_fn, metrics={},
         model.train()
         if engine.state.iteration > warmup_iters:
-            prev_k = 1
-            loss = None
-            new_ks_list = (1/lr_scale, lr_scale,)
-            with torch.no_grad():
-                for new_k in new_ks_list:
-                    correct_model(prev_k, new_k)
-                    y_pred = model(x)
-                    loss0 = loss
-                    loss = loss_fn(y_pred, y)
-                    prev_k = new_k
-                    print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr',
-                          optimizer.param_groups[0]['lr'], '*', new_k, 'loss', loss.item())
-                if loss0 < loss or (loss0 == loss and engine.state.iteration % 2):
-                    new_k = new_ks_list[0]
-                    correct_model(prev_k, new_k)
-            for group in optimizer.param_groups:
-                group['lr'] *= new_k
+            if engine.state.iteration % 2:
+                new_k = 1 / lr_scale
+            else:
+                new_k = lr_scale
+            for group in optimizer.param_groups:
+                group['lr'] *= new_k
+        else:
+            prev_k = new_k = 1
+        if engine.state.iteration > 1:
+            optimizer.step()
+        if engine.state.iteration > warmup_iters:
+            with torch.no_grad():
+                y_pred = model(x)
+                loss0 = loss_fn(y_pred, y)
+                print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr * {:5.3}'.format(new_k), 'loss', loss0.item())
+            prev_k = new_k
+            new_k = 1/new_k
+            correct_model(prev_k, new_k)
         optimizer.zero_grad()
         y_pred = model(x)
         loss = loss_fn(y_pred, y)
         loss.backward()
-        optimizer.step()
-        print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr', optimizer.param_groups[0]['lr'], 'loss', loss.item())
+        if engine.state.iteration > warmup_iters:
+            with torch.no_grad():
+                print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr * {:5.3}'.format(new_k), 'loss', loss.item())
+                if loss < loss0 or (loss == loss0 and engine.state.iteration % 2):
+                    for group in optimizer.param_groups:
+                        group['lr'] *= new_k/prev_k
+        print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr', optimizer.param_groups[0]['lr'], 'loss', loss.item())
         return y_pred, y
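
In brief, the v3 update step alternately multiplies the learning rate by lr_scale and 1/lr_scale on successive post-warmup iterations, re-measures the loss, and keeps the adjusted rate only when the comparison favours it. Below is a minimal usage sketch; the toy model, data, and hyperparameter values are placeholders, and it assumes the factory returns an ignite Engine in the same way as ignite's own create_supervised_trainer.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy model and data, purely illustrative.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
                    batch_size=32, shuffle=True)

# Build the trainer with the adaptive-LR update and run it as a normal ignite Engine.
trainer = create_adaptive_supervised_trainer(model, optimizer, loss_fn,
                                              lr_scale=1.1, warmup_iters=50)
trainer.run(loader, max_epochs=5)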