Commit 6fe67be4 authored by 吴磊(20博)

init

parent e0bd56b7
import argparse
import os
import sys

import torch
import torchvision
import torch.multiprocessing as mp
# from segmentation_models_pytorch.losses import FocalLoss
from torch.optim.lr_scheduler import ExponentialLR
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
# Make the project root importable so the trainer/ and utils/ packages resolve.
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)

from trainer.trainer import train_one_epoch
from utils.utils import mkdir, save_on_master, setup, is_main_process, reduce_dict, cleanup, ProgBar, \
    warmup_lr_scheduler

# Share tensors between DataLoader workers via the filesystem to avoid
# "too many open files" errors when using many workers.
torch.multiprocessing.set_sharing_strategy('file_system')
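
# utils.utils.setup is not included in this commit; a typical implementation
# matching the setup(rank, world_size, port=...) call below would look like
# the following (an assumption, not the project's actual code):
#
#   def setup(rank, world_size, port):
#       os.environ['MASTER_ADDR'] = 'localhost'
#       os.environ['MASTER_PORT'] = port
#       torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)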

def train(rank, world_size, args):
    # Initialise the distributed process group for this worker.
    setup(rank, world_size, port=args.port)

    # Standard ImageNet normalisation statistics.
    transform = Compose([
        ToTensor(),
        Resize(224),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    dataset = torchvision.datasets.CIFAR100("./CIFAR100", transform=transform, download=True)
    # Each rank draws a disjoint shard of the dataset.
    train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.num_workers,
                                              persistent_workers=True)

    model = torchvision.models.resnet34(num_classes=100).to(rank)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params)
    loss_func = [{"ce": torch.nn.CrossEntropyLoss().to(rank)}]
    loss_weights = [1.]
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    lr_scheduler = ExponentialLR(optimizer, gamma=0.9)

    # Optionally resume from the checkpoint saved at the end of the previous epoch.
    if args.start_epoch != 0:
        try:
            checkpoint = torch.load(os.path.join(args.checkpoint_path, 'model_%.2d.pth' % (args.start_epoch - 1)),
                                    map_location='cpu')
            model.module.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('CHECKPOINT LOADED!')
        except (OSError, RuntimeError) as e:
            print('Failed to load checkpoint:')
            print(e)

    for epoch in range(args.start_epoch, args.num_epochs):
        # Give the sampler the epoch number so each epoch gets a different shuffle.
        train_sampler.set_epoch(epoch)
        if is_main_process():
            print("Epoch %d/%d" % (epoch + 1, args.num_epochs))
        model.train()
        train_one_epoch(model, optimizer, data_loader, rank, loss_func, loss_weights)
        lr_scheduler.step()
        save_on_master({
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
        }, os.path.join(args.checkpoint_path, 'model_%.2d.pth' % epoch))
    cleanup()
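
# trainer/trainer.py is not part of this commit; the sketch below is an
# assumption of what train_one_epoch does, inferred only from the call above.
# The name _train_one_epoch_sketch and the loop body are hypothetical.
def _train_one_epoch_sketch(model, optimizer, data_loader, rank, loss_func, loss_weights):
    for images, targets in data_loader:
        images, targets = images.to(rank), targets.to(rank)
        outputs = model(images)
        # Weighted sum over the configured criteria, e.g. {"ce": CrossEntropyLoss}.
        loss = sum(weight * criterion(outputs, targets)
                   for weight, group in zip(loss_weights, loss_func)
                   for criterion in group.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()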

def run_training(train_fn, args, world_size):
    mkdir(args.checkpoint_path)
    # Spawn one process per GPU; mp.spawn prepends each process's rank to args.
    mp.spawn(train_fn,
             args=(world_size, args),
             nprocs=world_size,
             join=True)

if __name__ == '__main__':
    parser = argparse.ArgumentParser('example')
    parser.add_argument('--checkpoint_path', type=str, default="/data6/wulei/train_log/template/")
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--num_epochs', type=int, default=35)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--port', type=str, default="1892")
    # Comma-separated GPU ids, e.g. "0,1".
    parser.add_argument('--devices', type=str, default="0,1")
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.devices
    n_gpus = torch.cuda.device_count()
    run_training(train, args, n_gpus)
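
# Example launch on a single node with two visible GPUs (the script's file name
# is not given in this commit, so the one below is an assumption):
#   python train_cifar100_ddp.py --devices 0,1 --batch_size 64 --port 1892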