Commit 0574dec4 authored by IlyaOvodov's avatar IlyaOvodov

set_reproducibility, MeanLoss

parent 1919f600
from .threading_dataloader import BatchThreadingDataLoader, ThreadingDataLoader from .threading_dataloader import BatchThreadingDataLoader, ThreadingDataLoader
from .cached_dataset import CachedDataSet from .cached_dataset import CachedDataSet
from .losses import MeanLoss
from .utils import set_reproducibility
from .utils import reproducibility_worker_init_fn
from .mean_loss import MeanLoss
\ No newline at end of file
import torch
class MeanLoss(torch.nn.modules.loss._Loss):
'''
New loss calculated as mean of base binary loss calculated for all channels separately
loss values for individual channels are stored in get_val
'''
def __init__(self, base_binary_locc):
super(MeanLoss, self).__init__()
self.base_loss = base_binary_locc
self.loss_vals = []
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape)
self.loss_vals = [self.base_loss(y_pred[:, i, ...], y_true[:, i, ...]) for i in range(y_true.shape[1])]
res = torch.stack(self.loss_vals).mean()
self.loss_vals.append(res)
return res
def __len__(self):
'''
returns number of individual channel losses (not including the last value stored in self.loss_vals)
'''
return len(self.loss_funcs)
def get_val(self, index):
'''
returns function that returns individual channel loss cor channel index
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
'''
def call(*kargs, **kwargs):
return self.loss_vals[index]
return call
from .reproducibility import set_reproducibility, reproducibility_worker_init_fn
\ No newline at end of file
import random
import numpy as np
import torch
SEED = 241075
def set_reproducibility(seed = SEED):
'''
attempts to make calculations reproducible
'''
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic=True
def reproducibility_worker_init_fn(seed = SEED):
def worker_init_fn(worker_id):
np.random.seed(SEED)
return worker_init_fn
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment