Commit 71888884 authored by IlyaOvodov's avatar IlyaOvodov

refactoring

parent 06d2e8f2
...@@ -41,7 +41,7 @@ class AttrDict(OrderedDict): ...@@ -41,7 +41,7 @@ class AttrDict(OrderedDict):
elif isinstance(v, list): elif isinstance(v, list):
self[k] = [AttrDict(item) if isinstance(item, dict) else item for item in v] self[k] = [AttrDict(item) if isinstance(item, dict) else item for item in v]
def __repr__(self): def __str__(self):
def write_item(item, margin='\n'): def write_item(item, margin='\n'):
if isinstance(item, dict): if isinstance(item, dict):
s = '{' s = '{'
...@@ -62,7 +62,7 @@ class AttrDict(OrderedDict): ...@@ -62,7 +62,7 @@ class AttrDict(OrderedDict):
s += write_item(v, margin=margin + ' ') + "," s += write_item(v, margin=margin + ' ') + ","
s += ' ' + (']' if isinstance(item, list) else ')') s += ' ' + (']' if isinstance(item, list) else ')')
else: else:
s = repr(item) s = str(item)
return s return s
return write_item(self) return write_item(self)
...@@ -107,7 +107,7 @@ class AttrDict(OrderedDict): ...@@ -107,7 +107,7 @@ class AttrDict(OrderedDict):
dir_name = os.path.dirname(params_fn) dir_name = os.path.dirname(params_fn)
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
with open(params_fn, 'w+') as f: with open(params_fn, 'w+') as f:
s = repr(self) s = str(self)
s = s + '\nhash: ' + self.hash() s = s + '\nhash: ' + self.hash()
f.write(s) f.write(s)
if verbose >= 2: if verbose >= 2:
...@@ -135,7 +135,7 @@ class AttrDict(OrderedDict): ...@@ -135,7 +135,7 @@ class AttrDict(OrderedDict):
assert s[-1].startswith('hash:') assert s[-1].startswith('hash:')
params = AttrDict.load_from_str(s[:-1], data_root) params = AttrDict.load_from_str(s[:-1], data_root)
if verbose >= 2: if verbose >= 2:
print('params: '+ repr(params) + '\nhash: ' + params.hash()) print('params: '+ str(params) + '\nhash: ' + params.hash())
if verbose >= 1: if verbose >= 1:
print('loaded from ' + params_fn) print('loaded from ' + params_fn)
return params return params
...@@ -226,7 +226,7 @@ if __name__=='__main__': ...@@ -226,7 +226,7 @@ if __name__=='__main__':
), ),
), ),
) )
print(repr(m)) print(m)
fn = 'test_' + m.hash() fn = 'test_' + m.hash()
m.save(fn, can_overwrite=True) m.save(fn, can_overwrite=True)
......
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import CompositeLoss, MeanLoss from .losses import SimpleLoss, CompositeLoss, MeanLoss, LabelSmoothingBCEWithLogitsLoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
......
from .composite_loss import CompositeLoss from .composite_loss import SimpleLoss, CompositeLoss
from .mean_loss import MeanLoss from .mean_loss import MeanLoss
from .label_smoothing import LabelSmoothingBCEWithLogitsLoss
from typing import List
import torch import torch
from torch.nn.modules.loss import _Loss
class CompositeLoss(torch.nn.modules.loss._Loss):
def __init__(self, loss_funcs): class SimpleLoss(_Loss):
def __init__(self, loss_func: _Loss, dict_key: str = None):
super(SimpleLoss, self).__init__()
self.loss_func = loss_func
self.dict_key = dict_key
self.val = None
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.dict_key and isinstance(y_true, dict):
y_true = y_true[self.dict_key]
self.val = self.loss_func(y_pred, y_true)
return self.val
def __len__(self):
return 0
def get_val(self):
def call(*kargs, **kwargs):
return self.val
return call
class CompositeLoss(_Loss):
def __init__(self, loss_funcs: List):
super(CompositeLoss, self).__init__() super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs self.loss_funcs = loss_funcs
self.loss_vals = [None] * (len(self.loss_funcs) + 1) self.val = None
self.sub_vals = [None] * (len(self.loss_funcs) + 1)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
self.loss_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs] self.sub_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs]
res = sum([w * self.loss_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)]) self.val = sum([w * self.sub_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)])
self.loss_vals.append(res) return self.val
return res
def __len__(self): def __len__(self):
return len(self.loss_funcs) return len(self.loss_funcs)
def get_val(self, index): def get_val(self):
def call(*kargs, **kwargs):
return self.val
return call
def get_subval(self, index):
def call(*kargs, **kwargs): def call(*kargs, **kwargs):
return self.loss_vals[index] return self.sub_vals[index]
return call return call
......
import torch
import torch.nn.functional as F
def LabelSmoothingBCEWithLogitsLoss(label_smoothing = 0, **kwargs):
def loss(y, y_gt):
'''
s = 1/(1+exp(-x))
L_smooth = -(y_gt*log(s) + (1-y_gt)*log(1-s)) = x(1-y_gt) - log(s)
L_min = -(y_gt*log(y_gt) + (1-y_gt)*log(1-y_gt))
L = L_smooth - L_min = KL distance
'''
y_gt = label_smoothing + (1-2*label_smoothing)*y_gt
loss_val = y*(1-y_gt) - F.logsigmoid(y)
if label_smoothing:
loss_val += y_gt*torch.log(y_gt) + (1-y_gt)*torch.log(1-y_gt)
loss_val_mean = loss_val.mean()
return loss_val_mean
return loss
\ No newline at end of file
...@@ -4,34 +4,37 @@ import torch ...@@ -4,34 +4,37 @@ import torch
class MeanLoss(torch.nn.modules.loss._Loss): class MeanLoss(torch.nn.modules.loss._Loss):
''' '''
New loss calculated as mean of base binary loss calculated for all channels separately New loss calculated as mean of base binary loss calculated for all channels separately
loss values for individual channels are stored in get_val loss values for individual channels are stored in sub_vals
''' '''
def __init__(self, base_binary_locc): def __init__(self, base_binary_locc):
super(MeanLoss, self).__init__() super(MeanLoss, self).__init__()
self.base_loss = base_binary_locc self.base_loss = base_binary_locc
self.loss_vals = [] self.sub_vals = []
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape) assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape)
self.loss_vals = [self.base_loss(y_pred[:, i, ...], y_true[:, i, ...]) for i in range(y_true.shape[1])] self.sub_vals = [self.base_loss(y_pred[:, i, ...], y_true[:, i, ...]) for i in range(y_true.shape[1])]
res = torch.stack(self.loss_vals).mean() self.val = torch.stack(self.sub_vals).mean()
self.loss_vals.append(res) return self.val
return res
def __len__(self): def __len__(self):
''' '''
returns number of individual channel losses (not including the last value stored in self.loss_vals) returns number of individual channel losses
''' '''
return len(self.loss_funcs) return len(self.sub_vals)
def get_val(self, index): def get_val(self):
''' '''
returns function that returns individual channel loss cor channel index returns function to get last result
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
''' '''
def call(*kargs, **kwargs): def call(*kargs, **kwargs):
return self.loss_vals[index] return self.val
return call return call
def get_subval(self, index):
'''
returns function that returns individual channel loss cor channel index
'''
def call(*kargs, **kwargs):
return self.sub_vals[index]
return call
\ No newline at end of file
from typing import Callable from typing import Callable
import torch import torch
from ovotools import AttrDict from ovotools import AttrDict
from ..losses import MeanLoss, CompositeLoss from ..losses import SimpleLoss, CompositeLoss, MeanLoss
def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) -> object: def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) -> object:
...@@ -23,7 +23,7 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) -> ...@@ -23,7 +23,7 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
all_kwargs = kwargs.copy() all_kwargs = kwargs.copy()
p = params.get('params', dict()) p = params.get('params', dict())
all_kwargs.update(p) all_kwargs.update(p)
print('creating: ', params['type'], p) print('creating: ', params['type'], repr(dict(p)))
obj = eval_func(params['type'])(*args, **all_kwargs) obj = eval_func(params['type'])(*args, **all_kwargs)
return obj return obj
...@@ -148,6 +148,7 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l ...@@ -148,6 +148,7 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l
loss = create_object(loss_params, eval_func) loss = create_object(loss_params, eval_func)
if loss_params.get('mean', False): if loss_params.get('mean', False):
loss = MeanLoss(loss) loss = MeanLoss(loss)
loss = SimpleLoss(loss, loss_params.get('key'))
return loss return loss
else: else:
loss_funcs = [] loss_funcs = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment