Commit 6fe67be4
authored Apr 01, 2022 by 吴磊(20博)
init
parent e0bd56b7
Showing 2 changed files with 105 additions and 0 deletions:

train/__init__.py       +0   −0
train/train_mnist.py    +105 −0
train/__init__.py (new file, mode 100644; empty)
train/train_mnist.py (new file, mode 100644)
import argparse
import os
import random
import sys

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision
# from segmentation_models_pytorch.losses import FocalLoss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DistributedSampler
from torchvision.transforms import *

# Make the repository root importable so the trainer/ and utils/ packages resolve.
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)

from trainer.trainer import train_one_epoch
from utils.utils import mkdir, save_on_master, setup, is_main_process, reduce_dict, cleanup, ProgBar, \
    warmup_lr_scheduler

# Share tensors between worker processes via the filesystem to avoid
# "too many open files" errors with many DataLoader workers.
torch.multiprocessing.set_sharing_strategy('file_system')


def train(rank, world_size, args):
    # Initialize the process group for this rank.
    setup(rank, world_size, port=args.port)

    transform = Compose([
        ToTensor(),
        Resize(224),
        # ImageNet mean/std normalization.
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.2))
    ])
    # Assumes CIFAR-100 has already been downloaded to ./CIFAR100.
    dataset = torchvision.datasets.CIFAR100("./CIFAR100", transform=transform)
    # DistributedSampler(dataset, num_replicas, rank, shuffle)
    train_sampler = DistributedSampler(dataset, world_size, rank, True)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                              sampler=train_sampler,
                                              num_workers=args.num_workers,
                                              persistent_workers=True)

    model = torchvision.models.resnet34(num_classes=100).to(rank)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params)
    loss_func = [{"ce": torch.nn.CrossEntropyLoss().to(rank)}]
    loss_weights = [1.]
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    lr_scheduler = ExponentialLR(optimizer, 0.9)

    # Resume from the previous epoch's checkpoint when start_epoch > 0.
    if args.start_epoch != 0:
        try:
            checkpoint = torch.load(
                os.path.join(args.checkpoint_path, 'model_%.2d.pth' % (args.start_epoch - 1)),
                map_location='cpu')
            model.module.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('CHECKPOINT LOADED!')
        except RuntimeError as e:
            print('load error')
            print(e)

    for epoch in range(args.start_epoch, args.num_epochs):
        # Reseed the sampler so each epoch sees a different shuffling order.
        train_sampler.set_epoch(epoch)
        if is_main_process():
            print("Epoch %d/%d" % (epoch + 1, args.num_epochs))
        model.train()
        train_one_epoch(model, optimizer, data_loader, rank, loss_func, loss_weights)
        # Advance the exponential LR decay once per epoch.
        lr_scheduler.step()
        # Only the main process writes the checkpoint.
        save_on_master({
            'model': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
        }, os.path.join(args.checkpoint_path, 'model_%.2d.pth' % epoch))
    cleanup()


def run_training(train_fn, args, world_size):
    mkdir(args.checkpoint_path)
    # Spawn one training process per GPU; each receives its rank as the first argument.
    mp.spawn(train_fn, args=(world_size, args), nprocs=world_size, join=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('example')
    parser.add_argument('--checkpoint_path', type=str, default="/data6/wulei/train_log/template/")
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--num_epochs', type=int, default=35)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--port', type=str, default="1892")
    # type=list splits the raw argument string into characters, e.g. "01" -> ['0', '1'].
    parser.add_argument('--devices', type=list, default=[0, 1])
    args = parser.parse_args()

    devices = ','.join([str(s) for s in args.devices])
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    n_gpus = torch.cuda.device_count()
    run_training(train, args, n_gpus)