Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
H
hi-template
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
吴磊(20博)
hi-template
Commits
fb1ea036
Commit
fb1ea036
authored
Apr 01, 2022
by
吴磊(20博)
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
init
parent
300bc2b9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
49 additions
and
0 deletions
+49
-0
.gitignore
.gitignore
+3
-0
trainer.py
trainer/trainer.py
+46
-0
__init__.py
utils/__init__.py
+0
-0
utils.py
utils/utils.py
+0
-0
No files found.
.gitignore
0 → 100644
View file @
fb1ea036
.idea
.DS_Store
\ No newline at end of file
trainer/trainer.py
View file @
fb1ea036
import
torch
from
utils.utils
import
*
def
train_one_epoch
(
model
,
optimizer
,
data_loader
,
device
,
epoch
,
loss
,
loss_weights
,
warmup
=
False
):
lr_scheduler
=
None
log_bar
=
None
if
is_main_process
():
log_bar
=
ProgBar
(
len
(
data_loader
))
for
i
,
(
inputs
,
targets
)
in
enumerate
(
data_loader
):
if
isinstance
(
inputs
,
torch
.
Tensor
):
inputs
=
[
inputs
]
if
isinstance
(
targets
,
torch
.
Tensor
):
targets
=
[
targets
]
for
idx_inputs
in
range
(
len
(
inputs
)):
inputs
[
idx_inputs
]
=
inputs
[
idx_inputs
]
.
to
(
device
)
for
idx_target
in
range
(
len
(
targets
)):
targets
[
idx_target
]
=
targets
[
idx_target
]
.
to
(
device
)
outputs
=
model
(
*
inputs
)
if
isinstance
(
outputs
,
torch
.
Tensor
):
outputs
=
[
outputs
]
losses
=
[]
it
=
zip
(
outputs
,
targets
,
loss
,
loss_weights
)
for
output
,
target
,
loss_dict
,
loss_weight
in
it
:
loss_value_dict
=
dict
()
for
loss_name
,
loss_fn
in
loss_dict
.
items
():
loss_value
=
loss_fn
(
output
,
target
)
loss_value_dict
.
update
({
loss_name
:
loss_value
*
loss_weight
})
losses
.
append
(
loss_value_dict
)
losses_value
=
sum
([
sum
(
loss_v
for
loss_v
in
loss_value_dict
.
values
())
for
loss_value_dict
in
losses
])
optimizer
.
zero_grad
()
losses_value
.
backward
()
optimizer
.
step
()
loss_dict_reduced
=
[]
for
loss_value_dict
in
losses
:
loss_dict_reduced
.
append
(
reduce_dict
(
loss_value_dict
))
if
is_main_process
()
and
log_bar
is
not
None
:
logs
=
[]
for
l_i
,
loss_dict
in
enumerate
(
loss_dict_reduced
):
for
k
,
v
in
loss_dict
.
items
():
logs
.
append
((
k
+
"_
%
d"
%
l_i
,
v
.
item
()))
log_bar
.
update
(
i
+
1
,
logs
)
return
None
\ No newline at end of file
utils/__init__.py
0 → 100644
View file @
fb1ea036
utils/utils.py
0 → 100644
View file @
fb1ea036
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment