Commit 71888884 authored by IlyaOvodov's avatar IlyaOvodov

refactoring

parent 06d2e8f2
......@@ -41,7 +41,7 @@ class AttrDict(OrderedDict):
elif isinstance(v, list):
self[k] = [AttrDict(item) if isinstance(item, dict) else item for item in v]
def __repr__(self):
def __str__(self):
def write_item(item, margin='\n'):
if isinstance(item, dict):
s = '{'
......@@ -62,7 +62,7 @@ class AttrDict(OrderedDict):
s += write_item(v, margin=margin + ' ') + ","
s += ' ' + (']' if isinstance(item, list) else ')')
else:
s = repr(item)
s = str(item)
return s
return write_item(self)
......@@ -107,7 +107,7 @@ class AttrDict(OrderedDict):
dir_name = os.path.dirname(params_fn)
os.makedirs(dir_name, exist_ok=True)
with open(params_fn, 'w+') as f:
s = repr(self)
s = str(self)
s = s + '\nhash: ' + self.hash()
f.write(s)
if verbose >= 2:
......@@ -135,7 +135,7 @@ class AttrDict(OrderedDict):
assert s[-1].startswith('hash:')
params = AttrDict.load_from_str(s[:-1], data_root)
if verbose >= 2:
print('params: '+ repr(params) + '\nhash: ' + params.hash())
print('params: '+ str(params) + '\nhash: ' + params.hash())
if verbose >= 1:
print('loaded from ' + params_fn)
return params
......@@ -226,7 +226,7 @@ if __name__=='__main__':
),
),
)
print(repr(m))
print(m)
fn = 'test_' + m.hash()
m.save(fn, can_overwrite=True)
......
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import CompositeLoss, MeanLoss
from .losses import SimpleLoss, CompositeLoss, MeanLoss, LabelSmoothingBCEWithLogitsLoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
......
from .composite_loss import CompositeLoss
from .composite_loss import SimpleLoss, CompositeLoss
from .mean_loss import MeanLoss
from .label_smoothing import LabelSmoothingBCEWithLogitsLoss
from typing import List
import torch
from torch.nn.modules.loss import _Loss
class CompositeLoss(torch.nn.modules.loss._Loss):
def __init__(self, loss_funcs):
class SimpleLoss(_Loss):
def __init__(self, loss_func: _Loss, dict_key: str = None):
super(SimpleLoss, self).__init__()
self.loss_func = loss_func
self.dict_key = dict_key
self.val = None
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.dict_key and isinstance(y_true, dict):
y_true = y_true[self.dict_key]
self.val = self.loss_func(y_pred, y_true)
return self.val
def __len__(self):
return 0
def get_val(self):
def call(*kargs, **kwargs):
return self.val
return call
class CompositeLoss(_Loss):
def __init__(self, loss_funcs: List):
super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs
self.loss_vals = [None] * (len(self.loss_funcs) + 1)
self.val = None
self.sub_vals = [None] * (len(self.loss_funcs) + 1)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
self.loss_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs]
res = sum([w * self.loss_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)])
self.loss_vals.append(res)
return res
self.sub_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs]
self.val = sum([w * self.sub_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)])
return self.val
def __len__(self):
return len(self.loss_funcs)
def get_val(self, index):
def get_val(self):
def call(*kargs, **kwargs):
return self.val
return call
def get_subval(self, index):
def call(*kargs, **kwargs):
return self.loss_vals[index]
return self.sub_vals[index]
return call
......
import torch
import torch.nn.functional as F
def LabelSmoothingBCEWithLogitsLoss(label_smoothing = 0, **kwargs):
def loss(y, y_gt):
'''
s = 1/(1+exp(-x))
L_smooth = -(y_gt*log(s) + (1-y_gt)*log(1-s)) = x(1-y_gt) - log(s)
L_min = -(y_gt*log(y_gt) + (1-y_gt)*log(1-y_gt))
L = L_smooth - L_min = KL distance
'''
y_gt = label_smoothing + (1-2*label_smoothing)*y_gt
loss_val = y*(1-y_gt) - F.logsigmoid(y)
if label_smoothing:
loss_val += y_gt*torch.log(y_gt) + (1-y_gt)*torch.log(1-y_gt)
loss_val_mean = loss_val.mean()
return loss_val_mean
return loss
\ No newline at end of file
......@@ -4,34 +4,37 @@ import torch
class MeanLoss(torch.nn.modules.loss._Loss):
'''
New loss calculated as mean of base binary loss calculated for all channels separately
loss values for individual channels are stored in get_val
loss values for individual channels are stored in sub_vals
'''
def __init__(self, base_binary_locc):
super(MeanLoss, self).__init__()
self.base_loss = base_binary_locc
self.loss_vals = []
self.sub_vals = []
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape)
self.loss_vals = [self.base_loss(y_pred[:, i, ...], y_true[:, i, ...]) for i in range(y_true.shape[1])]
res = torch.stack(self.loss_vals).mean()
self.loss_vals.append(res)
return res
self.sub_vals = [self.base_loss(y_pred[:, i, ...], y_true[:, i, ...]) for i in range(y_true.shape[1])]
self.val = torch.stack(self.sub_vals).mean()
return self.val
def __len__(self):
'''
returns number of individual channel losses (not including the last value stored in self.loss_vals)
returns number of individual channel losses
'''
return len(self.loss_funcs)
return len(self.sub_vals)
def get_val(self, index):
def get_val(self):
'''
returns function that returns individual channel loss cor channel index
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
returns function to get last result
'''
def call(*kargs, **kwargs):
return self.loss_vals[index]
return self.val
return call
def get_subval(self, index):
'''
returns function that returns individual channel loss cor channel index
'''
def call(*kargs, **kwargs):
return self.sub_vals[index]
return call
\ No newline at end of file
from typing import Callable
import torch
from ovotools import AttrDict
from ..losses import MeanLoss, CompositeLoss
from ..losses import SimpleLoss, CompositeLoss, MeanLoss
def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) -> object:
......@@ -23,7 +23,7 @@ def create_object(params: dict, eval_func: Callable = eval, *args, **kwargs) ->
all_kwargs = kwargs.copy()
p = params.get('params', dict())
all_kwargs.update(p)
print('creating: ', params['type'], p)
print('creating: ', params['type'], repr(dict(p)))
obj = eval_func(params['type'])(*args, **all_kwargs)
return obj
......@@ -148,6 +148,7 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l
loss = create_object(loss_params, eval_func)
if loss_params.get('mean', False):
loss = MeanLoss(loss)
loss = SimpleLoss(loss, loss_params.get('key'))
return loss
else:
loss_funcs = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment