Commit 1237f3d8 authored by IlyaOvodov's avatar IlyaOvodov

To new ignite with my update

parent 0b5d797d
......@@ -271,21 +271,12 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
Returns:
Engine: a trainer engine with supervised update function.
"""
if device:
model.to(device)
engine = ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device, non_blocking, prepare_batch,
output_transform = lambda x, y, y_pred, loss: (y_pred, y,))
def _update(engine, batch):
@engine.on(Events.ITERATION_STARTED)
def reset_output(engine):
engine.state.output = None
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return y_pred, y
engine = ignite.engine.Engine(_update)
for name, metric in metrics.items():
metric.attach(engine, 'train:' + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment