Commit 1237f3d8 authored by IlyaOvodov's avatar IlyaOvodov

To new ignite with my update

parent 0b5d797d
...@@ -271,21 +271,12 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={}, ...@@ -271,21 +271,12 @@ def create_supervised_trainer(model, optimizer, loss_fn, metrics={},
Returns: Returns:
Engine: a trainer engine with supervised update function. Engine: a trainer engine with supervised update function.
""" """
if device: engine = ignite.engine.create_supervised_trainer(model, optimizer, loss_fn, device, non_blocking, prepare_batch,
model.to(device) output_transform = lambda x, y, y_pred, loss: (y_pred, y,))
def _update(engine, batch): @engine.on(Events.ITERATION_STARTED)
def reset_output(engine):
engine.state.output = None engine.state.output = None
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return y_pred, y
engine = ignite.engine.Engine(_update)
for name, metric in metrics.items(): for name, metric in metrics.items():
metric.attach(engine, 'train:' + name) metric.attach(engine, 'train:' + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment