Commit fb1ea036 authored by 吴磊(20博)

init

parent 300bc2b9
.idea
.DS_Store
\ No newline at end of file
import torch
from utils.utils import *


def train_one_epoch(model, optimizer, data_loader, device, epoch, loss, loss_weights, warmup=False):
    # NOTE: `epoch` and `warmup` are accepted but not used yet; `lr_scheduler`
    # stays None, so no per-iteration warmup schedule is stepped in this loop.
    lr_scheduler = None
    log_bar = None
    if is_main_process():
        log_bar = ProgBar(len(data_loader))
    for i, (inputs, targets) in enumerate(data_loader):
        # Normalize single tensors into lists so single- and multi-head models
        # share one code path.
        if isinstance(inputs, torch.Tensor):
            inputs = [inputs]
        if isinstance(targets, torch.Tensor):
            targets = [targets]
        for idx_inputs in range(len(inputs)):
            inputs[idx_inputs] = inputs[idx_inputs].to(device)
        for idx_target in range(len(targets)):
            targets[idx_target] = targets[idx_target].to(device)
        outputs = model(*inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
        # Build one {loss_name: weighted loss value} dict per output head.
        losses = []
        it = zip(outputs, targets, loss, loss_weights)
        for output, target, loss_dict, loss_weight in it:
            loss_value_dict = dict()
            for loss_name, loss_fn in loss_dict.items():
                loss_value = loss_fn(output, target)
                loss_value_dict.update({loss_name: loss_value * loss_weight})
            losses.append(loss_value_dict)
        # Total loss: the sum of every weighted term across all heads.
        losses_value = sum(sum(loss_v for loss_v in loss_value_dict.values()) for loss_value_dict in losses)
        optimizer.zero_grad()
        losses_value.backward()
        optimizer.step()
        # Average the per-head loss dicts across processes for logging.
        loss_dict_reduced = []
        for loss_value_dict in losses:
            loss_dict_reduced.append(reduce_dict(loss_value_dict))
        if is_main_process() and log_bar is not None:
            logs = []
            for l_i, loss_dict in enumerate(loss_dict_reduced):
                for k, v in loss_dict.items():
                    logs.append((k + "_%d" % l_i, v.item()))
            log_bar.update(i + 1, logs)
    return None
\ No newline at end of file
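For reference, a minimal usage sketch of the loop above. Everything here is hypothetical and not part of this commit: the toy model, dataset, loss names, and hyperparameters are placeholders, and the single-head `loss` / `loss_weights` lists are just one way to satisfy the per-output convention the function expects.

# Hypothetical usage sketch (not part of this commit).
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy single-output model and random data, just to exercise the loop.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(32, 10))
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 10, (64,)))
data_loader = DataLoader(dataset, batch_size=8)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# One loss dict per model output, one scalar weight per output.
loss = [{"ce": torch.nn.CrossEntropyLoss()}]
loss_weights = [1.0]

for epoch in range(10):
    train_one_epoch(model, optimizer, data_loader, device, epoch,
                    loss, loss_weights, warmup=(epoch == 0))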
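The loop also relies on `is_main_process`, `reduce_dict`, and `ProgBar`, wildcard-imported from `utils.utils`, whose source is not part of this diff. Below is a sketch of what the two distributed helpers typically look like, modeled on the torchvision reference utilities; this is an assumption, and the actual implementations in `utils.utils` may differ. `ProgBar` is assumed to be a Keras-style progress bar with an `update(step, values)` method and is not sketched here.

# Assumed helpers (sketch only): the real utils.utils in this repo may differ.
import torch
import torch.distributed as dist


def is_main_process():
    # Rank 0, or any non-distributed run, is treated as the main process.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0


def reduce_dict(input_dict, average=True):
    # Sum (and optionally average) every tensor in the dict across all
    # processes so logged losses reflect the whole distributed batch.
    world_size = 1
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        return {k: v for k, v in zip(names, values)}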