Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
OvoTools
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
林帅浩
OvoTools
Commits
0574dec4
Commit
0574dec4
authored
Oct 11, 2019
by
IlyaOvodov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
set_reproducibility, MeanLoss
parent
1919f600
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
0 deletions
+69
-0
__init__.py
ovotools/pytorch/__init__.py
+5
-0
__init__.py
ovotools/pytorch/losses/__init__.py
+2
-0
mean_loss.py
ovotools/pytorch/losses/mean_loss.py
+37
-0
__init__.py
ovotools/pytorch/utils/__init__.py
+2
-0
reproducibility.py
ovotools/pytorch/utils/reproducibility.py
+23
-0
No files found.
ovotools/pytorch/__init__.py
View file @
0574dec4
from
.threading_dataloader
import
BatchThreadingDataLoader
,
ThreadingDataLoader
from
.cached_dataset
import
CachedDataSet
from
.losses
import
MeanLoss
from
.utils
import
set_reproducibility
from
.utils
import
reproducibility_worker_init_fn
ovotools/pytorch/losses/__init__.py
0 → 100644
View file @
0574dec4
from
.mean_loss
import
MeanLoss
\ No newline at end of file
ovotools/pytorch/losses/mean_loss.py
0 → 100644
View file @
0574dec4
import
torch
class
MeanLoss
(
torch
.
nn
.
modules
.
loss
.
_Loss
):
'''
New loss calculated as mean of base binary loss calculated for all channels separately
loss values for individual channels are stored in get_val
'''
def
__init__
(
self
,
base_binary_locc
):
super
(
MeanLoss
,
self
)
.
__init__
()
self
.
base_loss
=
base_binary_locc
self
.
loss_vals
=
[]
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
y_true
.
shape
==
y_pred
.
shape
,
(
y_pred
.
shape
,
y_true
.
shape
)
self
.
loss_vals
=
[
self
.
base_loss
(
y_pred
[:,
i
,
...
],
y_true
[:,
i
,
...
])
for
i
in
range
(
y_true
.
shape
[
1
])]
res
=
torch
.
stack
(
self
.
loss_vals
)
.
mean
()
self
.
loss_vals
.
append
(
res
)
return
res
def
__len__
(
self
):
'''
returns number of individual channel losses (not including the last value stored in self.loss_vals)
'''
return
len
(
self
.
loss_funcs
)
def
get_val
(
self
,
index
):
'''
returns function that returns individual channel loss cor channel index
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
'''
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
loss_vals
[
index
]
return
call
ovotools/pytorch/utils/__init__.py
0 → 100644
View file @
0574dec4
from
.reproducibility
import
set_reproducibility
,
reproducibility_worker_init_fn
\ No newline at end of file
ovotools/pytorch/utils/reproducibility.py
0 → 100644
View file @
0574dec4
import
random
import
numpy
as
np
import
torch
SEED
=
241075
def
set_reproducibility
(
seed
=
SEED
):
'''
attempts to make calculations reproducible
'''
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
def
reproducibility_worker_init_fn
(
seed
=
SEED
):
def
worker_init_fn
(
worker_id
):
np
.
random
.
seed
(
SEED
)
return
worker_init_fn
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment