Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
OvoTools
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
林帅浩
OvoTools
Commits
71888884
Commit
71888884
authored
Oct 13, 2019
by
IlyaOvodov
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactoring
parent
06d2e8f2
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
87 additions
and
32 deletions
+87
-32
params.py
ovotools/params/params.py
+5
-5
__init__.py
ovotools/pytorch/__init__.py
+1
-1
__init__.py
ovotools/pytorch/losses/__init__.py
+2
-1
composite_loss.py
ovotools/pytorch/losses/composite_loss.py
+39
-9
label_smoothing.py
ovotools/pytorch/losses/label_smoothing.py
+19
-0
mean_loss.py
ovotools/pytorch/losses/mean_loss.py
+18
-14
create_object.py
ovotools/pytorch/utils/create_object.py
+3
-2
No files found.
ovotools/params/params.py
View file @
71888884
...
@@ -41,7 +41,7 @@ class AttrDict(OrderedDict):
...
@@ -41,7 +41,7 @@ class AttrDict(OrderedDict):
elif
isinstance
(
v
,
list
):
elif
isinstance
(
v
,
list
):
self
[
k
]
=
[
AttrDict
(
item
)
if
isinstance
(
item
,
dict
)
else
item
for
item
in
v
]
self
[
k
]
=
[
AttrDict
(
item
)
if
isinstance
(
item
,
dict
)
else
item
for
item
in
v
]
def
__
rep
r__
(
self
):
def
__
st
r__
(
self
):
def
write_item
(
item
,
margin
=
'
\n
'
):
def
write_item
(
item
,
margin
=
'
\n
'
):
if
isinstance
(
item
,
dict
):
if
isinstance
(
item
,
dict
):
s
=
'{'
s
=
'{'
...
@@ -62,7 +62,7 @@ class AttrDict(OrderedDict):
...
@@ -62,7 +62,7 @@ class AttrDict(OrderedDict):
s
+=
write_item
(
v
,
margin
=
margin
+
' '
)
+
","
s
+=
write_item
(
v
,
margin
=
margin
+
' '
)
+
","
s
+=
' '
+
(
']'
if
isinstance
(
item
,
list
)
else
')'
)
s
+=
' '
+
(
']'
if
isinstance
(
item
,
list
)
else
')'
)
else
:
else
:
s
=
rep
r
(
item
)
s
=
st
r
(
item
)
return
s
return
s
return
write_item
(
self
)
return
write_item
(
self
)
...
@@ -107,7 +107,7 @@ class AttrDict(OrderedDict):
...
@@ -107,7 +107,7 @@ class AttrDict(OrderedDict):
dir_name
=
os
.
path
.
dirname
(
params_fn
)
dir_name
=
os
.
path
.
dirname
(
params_fn
)
os
.
makedirs
(
dir_name
,
exist_ok
=
True
)
os
.
makedirs
(
dir_name
,
exist_ok
=
True
)
with
open
(
params_fn
,
'w+'
)
as
f
:
with
open
(
params_fn
,
'w+'
)
as
f
:
s
=
rep
r
(
self
)
s
=
st
r
(
self
)
s
=
s
+
'
\n
hash: '
+
self
.
hash
()
s
=
s
+
'
\n
hash: '
+
self
.
hash
()
f
.
write
(
s
)
f
.
write
(
s
)
if
verbose
>=
2
:
if
verbose
>=
2
:
...
@@ -135,7 +135,7 @@ class AttrDict(OrderedDict):
...
@@ -135,7 +135,7 @@ class AttrDict(OrderedDict):
assert
s
[
-
1
]
.
startswith
(
'hash:'
)
assert
s
[
-
1
]
.
startswith
(
'hash:'
)
params
=
AttrDict
.
load_from_str
(
s
[:
-
1
],
data_root
)
params
=
AttrDict
.
load_from_str
(
s
[:
-
1
],
data_root
)
if
verbose
>=
2
:
if
verbose
>=
2
:
print
(
'params: '
+
rep
r
(
params
)
+
'
\n
hash: '
+
params
.
hash
())
print
(
'params: '
+
st
r
(
params
)
+
'
\n
hash: '
+
params
.
hash
())
if
verbose
>=
1
:
if
verbose
>=
1
:
print
(
'loaded from '
+
params_fn
)
print
(
'loaded from '
+
params_fn
)
return
params
return
params
...
@@ -226,7 +226,7 @@ if __name__=='__main__':
...
@@ -226,7 +226,7 @@ if __name__=='__main__':
),
),
),
),
)
)
print
(
repr
(
m
)
)
print
(
m
)
fn
=
'test_'
+
m
.
hash
()
fn
=
'test_'
+
m
.
hash
()
m
.
save
(
fn
,
can_overwrite
=
True
)
m
.
save
(
fn
,
can_overwrite
=
True
)
...
...
ovotools/pytorch/__init__.py
View file @
71888884
from
.data
import
CachedDataSet
,
BatchThreadingDataLoader
,
ThreadingDataLoader
from
.data
import
CachedDataSet
,
BatchThreadingDataLoader
,
ThreadingDataLoader
from
.losses
import
CompositeLoss
,
Mean
Loss
from
.losses
import
SimpleLoss
,
CompositeLoss
,
MeanLoss
,
LabelSmoothingBCEWithLogits
Loss
from
.modules
import
ReverseLayerF
,
DANN_module
,
Dann_Head
,
DannEncDecNet
from
.modules
import
ReverseLayerF
,
DANN_module
,
Dann_Head
,
DannEncDecNet
...
...
ovotools/pytorch/losses/__init__.py
View file @
71888884
from
.composite_loss
import
CompositeLoss
from
.composite_loss
import
SimpleLoss
,
CompositeLoss
from
.mean_loss
import
MeanLoss
from
.mean_loss
import
MeanLoss
from
.label_smoothing
import
LabelSmoothingBCEWithLogitsLoss
ovotools/pytorch/losses/composite_loss.py
View file @
71888884
from
typing
import
List
import
torch
import
torch
from
torch.nn.modules.loss
import
_Loss
class
CompositeLoss
(
torch
.
nn
.
modules
.
loss
.
_Loss
):
def
__init__
(
self
,
loss_funcs
):
class
SimpleLoss
(
_Loss
):
def
__init__
(
self
,
loss_func
:
_Loss
,
dict_key
:
str
=
None
):
super
(
SimpleLoss
,
self
)
.
__init__
()
self
.
loss_func
=
loss_func
self
.
dict_key
=
dict_key
self
.
val
=
None
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
dict_key
and
isinstance
(
y_true
,
dict
):
y_true
=
y_true
[
self
.
dict_key
]
self
.
val
=
self
.
loss_func
(
y_pred
,
y_true
)
return
self
.
val
def
__len__
(
self
):
return
0
def
get_val
(
self
):
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
val
return
call
class
CompositeLoss
(
_Loss
):
def
__init__
(
self
,
loss_funcs
:
List
):
super
(
CompositeLoss
,
self
)
.
__init__
()
super
(
CompositeLoss
,
self
)
.
__init__
()
self
.
loss_funcs
=
loss_funcs
self
.
loss_funcs
=
loss_funcs
self
.
loss_vals
=
[
None
]
*
(
len
(
self
.
loss_funcs
)
+
1
)
self
.
val
=
None
self
.
sub_vals
=
[
None
]
*
(
len
(
self
.
loss_funcs
)
+
1
)
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
.
loss_vals
=
[
loss_fn
(
y_pred
,
y_true
)
for
(
loss_fn
,
_
,)
in
self
.
loss_funcs
]
self
.
sub_vals
=
[
loss_fn
(
y_pred
,
y_true
)
for
(
loss_fn
,
_
,)
in
self
.
loss_funcs
]
res
=
sum
([
w
*
self
.
loss_vals
[
i
]
for
i
,
(
_
,
w
,)
in
enumerate
(
self
.
loss_funcs
)])
self
.
val
=
sum
([
w
*
self
.
sub_vals
[
i
]
for
i
,
(
_
,
w
,)
in
enumerate
(
self
.
loss_funcs
)])
self
.
loss_vals
.
append
(
res
)
return
self
.
val
return
res
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
loss_funcs
)
return
len
(
self
.
loss_funcs
)
def
get_val
(
self
,
index
):
def
get_val
(
self
):
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
val
return
call
def
get_subval
(
self
,
index
):
def
call
(
*
kargs
,
**
kwargs
):
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
loss
_vals
[
index
]
return
self
.
sub
_vals
[
index
]
return
call
return
call
...
...
ovotools/pytorch/losses/label_smoothing.py
0 → 100644
View file @
71888884
import
torch
import
torch.nn.functional
as
F
def
LabelSmoothingBCEWithLogitsLoss
(
label_smoothing
=
0
,
**
kwargs
):
def
loss
(
y
,
y_gt
):
'''
s = 1/(1+exp(-x))
L_smooth = -(y_gt*log(s) + (1-y_gt)*log(1-s)) = x(1-y_gt) - log(s)
L_min = -(y_gt*log(y_gt) + (1-y_gt)*log(1-y_gt))
L = L_smooth - L_min = KL distance
'''
y_gt
=
label_smoothing
+
(
1
-
2
*
label_smoothing
)
*
y_gt
loss_val
=
y
*
(
1
-
y_gt
)
-
F
.
logsigmoid
(
y
)
if
label_smoothing
:
loss_val
+=
y_gt
*
torch
.
log
(
y_gt
)
+
(
1
-
y_gt
)
*
torch
.
log
(
1
-
y_gt
)
loss_val_mean
=
loss_val
.
mean
()
return
loss_val_mean
return
loss
\ No newline at end of file
ovotools/pytorch/losses/mean_loss.py
View file @
71888884
...
@@ -4,34 +4,37 @@ import torch
...
@@ -4,34 +4,37 @@ import torch
class
MeanLoss
(
torch
.
nn
.
modules
.
loss
.
_Loss
):
class
MeanLoss
(
torch
.
nn
.
modules
.
loss
.
_Loss
):
'''
'''
New loss calculated as mean of base binary loss calculated for all channels separately
New loss calculated as mean of base binary loss calculated for all channels separately
loss values for individual channels are stored in
get_val
loss values for individual channels are stored in
sub_vals
'''
'''
def
__init__
(
self
,
base_binary_locc
):
def
__init__
(
self
,
base_binary_locc
):
super
(
MeanLoss
,
self
)
.
__init__
()
super
(
MeanLoss
,
self
)
.
__init__
()
self
.
base_loss
=
base_binary_locc
self
.
base_loss
=
base_binary_locc
self
.
loss
_vals
=
[]
self
.
sub
_vals
=
[]
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
y_pred
:
torch
.
Tensor
,
y_true
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
y_true
.
shape
==
y_pred
.
shape
,
(
y_pred
.
shape
,
y_true
.
shape
)
assert
y_true
.
shape
==
y_pred
.
shape
,
(
y_pred
.
shape
,
y_true
.
shape
)
self
.
loss_vals
=
[
self
.
base_loss
(
y_pred
[:,
i
,
...
],
y_true
[:,
i
,
...
])
for
i
in
range
(
y_true
.
shape
[
1
])]
self
.
sub_vals
=
[
self
.
base_loss
(
y_pred
[:,
i
,
...
],
y_true
[:,
i
,
...
])
for
i
in
range
(
y_true
.
shape
[
1
])]
res
=
torch
.
stack
(
self
.
loss_vals
)
.
mean
()
self
.
val
=
torch
.
stack
(
self
.
sub_vals
)
.
mean
()
self
.
loss_vals
.
append
(
res
)
return
self
.
val
return
res
def
__len__
(
self
):
def
__len__
(
self
):
'''
'''
returns number of individual channel losses
(not including the last value stored in self.loss_vals)
returns number of individual channel losses
'''
'''
return
len
(
self
.
loss_func
s
)
return
len
(
self
.
sub_val
s
)
def
get_val
(
self
,
index
):
def
get_val
(
self
):
'''
'''
returns function that returns individual channel loss cor channel index
returns function to get last result
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
'''
'''
def
call
(
*
kargs
,
**
kwargs
):
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
loss_vals
[
index
]
return
self
.
val
return
call
return
call
def
get_subval
(
self
,
index
):
'''
returns function that returns individual channel loss cor channel index
'''
def
call
(
*
kargs
,
**
kwargs
):
return
self
.
sub_vals
[
index
]
return
call
\ No newline at end of file
ovotools/pytorch/utils/create_object.py
View file @
71888884
from
typing
import
Callable
from
typing
import
Callable
import
torch
import
torch
from
ovotools
import
AttrDict
from
ovotools
import
AttrDict
from
..losses
import
MeanLoss
,
Composite
Loss
from
..losses
import
SimpleLoss
,
CompositeLoss
,
Mean
Loss
def
create_object
(
params
:
dict
,
eval_func
:
Callable
=
eval
,
*
args
,
**
kwargs
)
->
object
:
def
create_object
(
params
:
dict
,
eval_func
:
Callable
=
eval
,
*
args
,
**
kwargs
)
->
object
:
...
@@ -23,7 +23,7 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
...
@@ -23,7 +23,7 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
all_kwargs
=
kwargs
.
copy
()
all_kwargs
=
kwargs
.
copy
()
p
=
params
.
get
(
'params'
,
dict
())
p
=
params
.
get
(
'params'
,
dict
())
all_kwargs
.
update
(
p
)
all_kwargs
.
update
(
p
)
print
(
'creating: '
,
params
[
'type'
],
p
)
print
(
'creating: '
,
params
[
'type'
],
repr
(
dict
(
p
))
)
obj
=
eval_func
(
params
[
'type'
])(
*
args
,
**
all_kwargs
)
obj
=
eval_func
(
params
[
'type'
])(
*
args
,
**
all_kwargs
)
return
obj
return
obj
...
@@ -148,6 +148,7 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l
...
@@ -148,6 +148,7 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l
loss
=
create_object
(
loss_params
,
eval_func
)
loss
=
create_object
(
loss_params
,
eval_func
)
if
loss_params
.
get
(
'mean'
,
False
):
if
loss_params
.
get
(
'mean'
,
False
):
loss
=
MeanLoss
(
loss
)
loss
=
MeanLoss
(
loss
)
loss
=
SimpleLoss
(
loss
,
loss_params
.
get
(
'key'
))
return
loss
return
loss
else
:
else
:
loss_funcs
=
[]
loss_funcs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment