Commit bb6d7739 authored by IlyaOvodov

v2 (works but not better than baseline)

parent 96ae71c2
@@ -4,7 +4,7 @@ import ignite
 def create_adaptive_supervised_trainer(model, optimizer, loss_fn, metrics={},
                                        device=None, non_blocking=False,
-                                       prepare_batch=ignite.engine._prepare_batch, lr_scale = 2):
+                                       prepare_batch=ignite.engine._prepare_batch, lr_scale = 1.1, warmup_iters = 50):
     """
     Factory function for creating a trainer for supervised models.
@@ -35,98 +35,37 @@ def create_adaptive_supervised_trainer(model, optimizer, loss_fn, metrics={},
                 d_p = p.grad.data
                 p.data.add_(-group['lr'] * (new_k - prev_k), d_p)
     def _update1(engine, batch):
         model.train()
         optimizer.zero_grad()
+    def _update(engine, batch):
         x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
         y_pred = model(x)
         loss = loss_fn(y_pred, y)
         loss.backward()
         optimizer.step()
-        prev_k = 1
-        new_k = lr_scale
-        multiply_k = True
-        print('epoch', engine.state.epoch,'iter',engine.state.iteration, 'base', optimizer.param_groups[0]['lr'], 1, loss)
-        if engine.state.epoch <= 1:
-            return y_pred, y
-        model.train()
-        with torch.no_grad():
-            while True:
-                correct_model(prev_k, new_k)
-                y_pred2 = model(x)
-                loss2 = loss_fn(y_pred2, y)
-                print('new ', optimizer.param_groups[0]['lr'], new_k, loss2)
-                if loss2>=loss:
-                    correct_model(new_k, prev_k)
-                    if multiply_k and prev_k == 1:
-                        multiply_k = False
-                        new_k = prev_k/lr_scale
-                    else:
-                        break
-                else:
-                    y_pred = y_pred2
-                    loss = loss2
+        if engine.state.iteration > warmup_iters:
+            prev_k = 1
+            loss = None
+            new_ks_list = (1/lr_scale, lr_scale,)
+            with torch.no_grad():
+                for new_k in new_ks_list:
+                    correct_model(prev_k, new_k)
+                    y_pred = model(x)
+                    loss0 = loss
+                    loss = loss_fn(y_pred, y)
                     prev_k = new_k
-                    if multiply_k:
-                        new_k *= lr_scale
-                    else:
-                        new_k /= lr_scale
-        for group in optimizer.param_groups:
-            group['lr'] *= prev_k
-        print('fin ', optimizer.param_groups[0]['lr'], loss)
+            print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr',
+                  optimizer.param_groups[0]['lr'], '*', new_k, 'loss', loss.item())
+            if loss0 < loss or (loss0 == loss and engine.state.iteration % 2):
+                new_k = new_ks_list[0]
+                correct_model(prev_k, new_k)
+            for group in optimizer.param_groups:
+                group['lr'] *= new_k
         return y_pred, y
-    def _update(engine, batch):
-        model.train()
-        optimizer.zero_grad()
-        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        y_pred = model(x)
-        loss = loss_fn(y_pred, y)
-        loss.backward()
-        optimizer.step()
-        prev_k = 1
-        new_k = lr_scale
-        multiply_k = True
-        print('epoch', engine.state.epoch,'iter',engine.state.iteration, 'base', optimizer.param_groups[0]['lr'], 1, loss.item())
-        if engine.state.epoch <= 1:
-            return y_pred, y
-        with torch.no_grad():
-            while True:
-                correct_model(prev_k, new_k)
-                y_pred2 = model(x)
-                loss2 = loss_fn(y_pred2, y)
-                print('new ', optimizer.param_groups[0]['lr'], new_k, loss2.item())
-                if loss2>=loss:
-                    correct_model(new_k, prev_k)
-                    if multiply_k and prev_k == 1:
-                        multiply_k = False
-                        new_k = prev_k/lr_scale
-                    else:
-                        break
-                else:
-                    y_pred = y_pred2
-                    loss = loss2
-                    prev_k = new_k
-                    break
-        '''
-        if multiply_k:
-            new_k *= lr_scale
-        else:
-            new_k /= lr_scale
-        '''
-        for group in optimizer.param_groups:
-            group['lr'] *= prev_k
-        print('fin ', optimizer.param_groups[0]['lr'], loss)
-        print('iter\t{}.{}'.format(engine.state.epoch, engine.state.iteration), 'lr', optimizer.param_groups[0]['lr'], 'loss', loss.item())
-        return y_pred, y
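
The two context lines at the top of the second hunk are the tail of the correct_model helper defined earlier in the factory. For orientation, a minimal sketch of what the full helper presumably looks like; the enclosing loops and the p.grad guard are assumptions, only the two innermost lines appear in the diff:

    def correct_model(prev_k, new_k):
        # Defined inside the factory, so it closes over `optimizer`.
        # After optimizer.step() the weights hold p0 - prev_k * lr * grad;
        # this shifts them in place to p0 - new_k * lr * grad, i.e. it
        # rescales the step just taken without recomputing gradients.
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                p.data.add_(-group['lr'] * (new_k - prev_k), d_p)

Note this only rescales a plain SGD step; with momentum or weight decay the actual parameter update is not -lr * grad, so the rescaling would be approximate.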
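What the new _update does after warmup_iters: it probes the step just taken at scales 1/lr_scale and lr_scale, keeps whichever gives the lower loss on the same batch, and multiplies the learning rate by the winning factor (ties alternate by iteration parity). A toy trace with hypothetical scalars standing in for the tensors; the names w0, step and the quadratic loss are illustrative, not from the commit:

    # 1-D stand-in: loss(w) = w**2; after optimizer.step() w = w0 - step.
    s = 1.1                 # lr_scale
    w0, step = 1.0, 0.3     # hypothetical pre-step weight and lr * grad

    def loss_at(k):         # loss after rescaling the last step by k
        return (w0 - k * step) ** 2

    # The for-loop over new_ks_list walks the weights 1 -> 1/s -> s
    # via correct_model, recording the loss at each probe.
    loss0 = loss_at(1 / s)  # first probe
    loss = loss_at(s)       # second probe; weights are now at scale s
    if loss0 < loss:        # smaller step won: revert to 1/s, shrink lr
        chosen = 1 / s
    else:                   # bigger step won: stay at s, grow lr
        chosen = s
    print('scale lr by', chosen)  # here loss0~0.529 > loss~0.449, so 1.1

The comparison reuses the batch the step was computed from, so each probe costs only an extra forward pass, at the price of a greedy, noisy signal.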
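Assuming the factory, like ignite's stock create_supervised_trainer, wraps _update in an ignite.engine.Engine and returns it, a usage sketch would be as follows; the model, data, and hyperparameters are illustrative, not from the repository:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Illustrative toy data and model.
    X = torch.randn(256, 10)
    y = torch.randint(0, 2, (256,))
    train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

    model = torch.nn.Linear(10, 2)
    # Plain SGD: correct_model assumes the last step was exactly -lr * grad.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    trainer = create_adaptive_supervised_trainer(
        model, optimizer, loss_fn,
        lr_scale=1.1,      # probe factor; this commit lowers the default from 2
        warmup_iters=50,   # leave the lr untouched for the first 50 iterations
    )
    trainer.run(train_loader, max_epochs=5)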