Commit bddb0704 authored by 吴磊(20博)'s avatar 吴磊(20博)

init

parent 6fe67be4
......@@ -33,6 +33,7 @@ def train(rank, world_size, args):
Resize(224),
Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.2))
])
dataset = torchvision.datasets.CIFAR100("./CIFAR100", transform=transform)
train_sampler = DistributedSampler(dataset, world_size, rank, True)
data_loader = torch.utils.data.DataLoader(dataset,
......@@ -45,7 +46,11 @@ def train(rank, world_size, args):
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params)
# we can apply multiple loss functions on multiple outputs.
# loss = [{"loss_1_1": loss_1_1, "loss_1_2": loss_1_2}, {"loss_2_1": loss_2_1, "loss_2_2": loss_2_2}]
loss_func = [{"ce": torch.nn.CrossEntropyLoss().to(rank)}]
# the weights of every output.
loss_weights = [1.]
model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment