Commit 5b0d0410 authored by IlyaOvodov

from_dict transform, CompositeLoss, DANN

parent b59c3474
from .params import AttrDict
from . import params
from . import pytorch
from . import ignite
from .transforms import from_dict
def from_dict(key):
    '''
    Creates a transform function that extracts a value from a dict input.
    :param key: string
    :return: function: x -> x[key] if x is a dict; applied element-wise to
             tuples; x unchanged otherwise
    '''
    def call(data):
        if isinstance(data, dict):
            return data[key]
        elif isinstance(data, tuple):
            return tuple(call(di) for di in data)
        else:
            return data
    return call
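A minimal usage sketch (the sample dict below is illustrative, not part of this commit):

get_image = from_dict('image')
sample = {'image': [1, 2, 3], 'mask': [0, 1, 0]}
get_image(sample)            # -> [1, 2, 3]
get_image((sample, sample))  # -> ([1, 2, 3], [1, 2, 3]); tuples are mapped element-wise
get_image([7, 8, 9])         # -> [7, 8, 9]; non-dict input passes through unchanged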
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import CompositeLoss, MeanLoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
from .utils import set_reproducibility
from .utils import reproducibility_worker_init_fn
from .composite_loss import CompositeLoss
from .mean_loss import MeanLoss
import torch

class CompositeLoss(torch.nn.modules.loss._Loss):
    '''
    Weighted sum of several losses: takes a list of (loss_fn, weight) pairs.
    Unweighted values are cached in self.loss_vals; the weighted sum is
    appended as the last element.
    '''
    def __init__(self, loss_funcs):
        super(CompositeLoss, self).__init__()
        self.loss_funcs = loss_funcs
        self.loss_vals = [None] * (len(self.loss_funcs) + 1)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        self.loss_vals = [loss_fn(y_pred, y_true) for (loss_fn, _) in self.loss_funcs]
        res = sum([w * self.loss_vals[i] for i, (_, w) in enumerate(self.loss_funcs)])
        self.loss_vals.append(res)
        return res

    def __len__(self):
        # number of individual losses, not counting the cached weighted sum
        return len(self.loss_funcs)

    def get_val(self, index):
        # returns a function giving the cached loss value for index;
        # index=len(self) or index=-1 is the weighted sum returned by forward()
        def call(*kargs, **kwargs):
            return self.loss_vals[index]
        return call
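Example usage (the loss choices below are illustrative, not part of this commit):

import torch

loss = CompositeLoss([
    (torch.nn.BCEWithLogitsLoss(), 1.0),
    (torch.nn.MSELoss(), 0.5),
])
y_pred = torch.randn(4, 1)
y_true = torch.rand(4, 1)
total = loss(y_pred, y_true)      # 1.0 * BCE + 0.5 * MSE
bce_value = loss.get_val(0)()     # unweighted value of the first term
total_value = loss.get_val(-1)()  # the weighted sum cached by forward()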
@@ -9,7 +9,7 @@ class MeanLoss(torch.nn.modules.loss._Loss):
    def __init__(self, base_binary_loss):
        super(MeanLoss, self).__init__()
        self.base_loss = base_binary_loss
        self.loss_vals = []

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape)
@@ -19,16 +19,16 @@ class MeanLoss(torch.nn.modules.loss._Loss):
        return res

    def __len__(self):
        '''
        returns number of individual channel losses (not including the last
        value stored in self.loss_vals)
        '''
        return len(self.loss_vals) - 1

    def get_val(self, index):
        '''
        returns function that returns the individual channel loss for channel
        index; valid indexes are 0..len(self), where the last index=len(self)
        or index=-1 is the mean value returned by forward()
        '''
        def call(*kargs, **kwargs):
            return self.loss_vals[index]
        return call
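A hedged usage sketch (the forward body is elided in this diff; this assumes MeanLoss applies the wrapped binary loss per channel and returns the mean, as its docstrings describe):

mean_loss = MeanLoss(torch.nn.BCEWithLogitsLoss())
y_pred = torch.randn(4, 3, 8, 8)                    # 3 channels
y_true = torch.randint(0, 2, (4, 3, 8, 8)).float()
total = mean_loss(y_pred, y_true)                   # mean over per-channel losses
channel0 = mean_loss.get_val(0)()                   # loss for channel 0; index=-1 is the mean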
from .dann import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
'''
DANN module. See https://arxiv.org/abs/1505.07818, https://arxiv.org/abs/1409.7495
'''
import torch
import torch.nn as nn
from torch.autograd import Function

class ReverseLayerF(Function):
    '''
    Gradient reversal layer: identity on the forward pass, multiplies the
    incoming gradient by -alpha on the backward pass.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
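A quick sanity check of the gradient reversal (illustrative):

x = torch.ones(3, requires_grad=True)
y = ReverseLayerF.apply(x, 0.5)  # identity on the forward pass
y.sum().backward()
print(x.grad)                    # tensor([-0.5000, -0.5000, -0.5000])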
class DANN_module(nn.Module):
    def __init__(self, gamma=10., lambda_max=1., **kwargs):
        super(DANN_module, self).__init__()
        self.gamma = gamma
        self.lambda_max = lambda_max
        self.progress = nn.Parameter(torch.Tensor([0, ]), requires_grad=False)  # must be updated from outside

    def set_progress(self, value):
        self.progress.data = torch.Tensor([value, ]).to(self.progress.device)

    def forward(self, input_data):
        # lambda_p ramps from 0 to lambda_max as progress goes from 0 to 1
        lambda_p = self.lambda_max * (2. / (1. + torch.exp(-self.gamma * self.progress.data)) - 1)
        reverse_feature = ReverseLayerF.apply(input_data, lambda_p)
        return reverse_feature
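The schedule matches the lambda ramp-up from the DANN paper; the loop below just prints a few points of it (values are approximate):

m = DANN_module(gamma=10., lambda_max=1.)
for p in (0.0, 0.25, 0.5, 1.0):
    m.set_progress(p)
    lam = m.lambda_max * (2. / (1. + torch.exp(-m.gamma * m.progress.data)) - 1)
    print(p, float(lam))  # 0.0 -> 0.0, 0.25 -> ~0.85, 0.5 -> ~0.99, 1.0 -> ~1.0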
class Dann_Head(nn.Module):
    def __init__(self, input_dims, num_classes, **kwargs):
        super(Dann_Head, self).__init__()
        self.input_dims = input_dims
        self.pooling_depth = 100
        self.dann_module = DANN_module(**kwargs)
        self.pooling_modules = [nn.Sequential(
                                    nn.Conv2d(dim, self.pooling_depth, 3),
                                    nn.AdaptiveMaxPool2d(1)
                                )
                                for dim in input_dims]
        for i, m in enumerate(self.pooling_modules):
            self.add_module("pooling_module" + str(i), m)  # register modules held in a plain list
        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('dann_fc1', nn.Linear(self.pooling_depth * len(input_dims), 100))
        self.domain_classifier.add_module('dann_bn1', nn.BatchNorm1d(100))  # does not work for batch = 1
        #self.domain_classifier.add_module('dann_bn1', nn.GroupNorm(5, 100))
        self.domain_classifier.add_module('dann_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('dann_fc2', nn.Linear(100, num_classes))
        self.domain_classifier.add_module('dann_softmax', nn.LogSoftmax(dim=1))
        self.loss = nn.NLLLoss()
    def set_progress(self, value):
        self.dann_module.set_progress(value)

    def forward(self, inputs, y_true: torch.Tensor):
        features = []
        for i, x in enumerate(inputs[:len(self.input_dims)]):
            assert x.shape[1] == self.input_dims[i]
            x = self.dann_module(x)         # gradient reversal
            x = self.pooling_modules[i](x)  # per-scale conv + global max pool
            features.append(x)
        feature = torch.cat([f.flatten(start_dim=1) for f in features], dim=1)
        domain_output = self.domain_classifier(feature)
        loss = self.loss(domain_output, y_true.long())
        return loss
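A usage sketch (shapes and domain labels below are illustrative): Dann_Head expects a list of feature maps whose channel counts match input_dims, plus a per-sample domain label tensor.

head = Dann_Head(input_dims=[64, 128], num_classes=2)
feats = [torch.randn(4, 64, 16, 16), torch.randn(4, 128, 8, 8)]
domains = torch.tensor([0, 0, 1, 1])  # e.g. source = 0, target = 1
domain_loss = head(feats, domains)
domain_loss.backward()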
class DannEncDecNet(nn.Module):
    def __init__(self, base_net, input_dim, num_classes, **enc_dec_params):
        super(DannEncDecNet, self).__init__()
        self.net = base_net
        self.dann_head = Dann_Head(input_dim, num_classes, **enc_dec_params)
        self.bottleneck_data = None
        self.dann_loss = None

    def set_progress(self, value):
        self.dann_head.set_progress(value)

    def forward(self, x):
        """Sequentially pass `x` through the model's `encoder` and `decoder` (returns logits!)"""
        x = self.net.encoder(x)
        self.bottleneck_data = x  # cached for calc_dann_loss
        x = self.net.decoder(x)
        return x

    def calc_dann_loss(self, y_pred, y_true):
        self.dann_loss = self.dann_head(self.bottleneck_data, y_true)
        return self.dann_loss
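An end-to-end sketch, assuming the base net's encoder returns a list of feature maps whose channel counts match input_dim (the toy net and segmentation loss below are illustrative, not part of this commit):

import torch
import torch.nn as nn

class ToyEncDec(nn.Module):
    # Minimal stand-in for a real encoder/decoder net.
    def __init__(self):
        super().__init__()
        self.enc = nn.Conv2d(3, 32, 3, padding=1)
        self.dec = nn.Conv2d(32, 1, 1)
    def encoder(self, x):
        return [self.enc(x)]  # list of feature maps, consumed by Dann_Head
    def decoder(self, feats):
        return self.dec(feats[0])

net = DannEncDecNet(ToyEncDec(), input_dim=[32], num_classes=2)
net.set_progress(0.3)  # training progress in [0, 1], drives the lambda schedule
x = torch.randn(4, 3, 16, 16)
y_true = torch.rand(4, 1, 16, 16)
domains = torch.tensor([0, 0, 1, 1])
y_pred = net(x)        # also caches the bottleneck features
loss = nn.functional.binary_cross_entropy_with_logits(y_pred, y_true) \
       + net.calc_dann_loss(y_pred, domains)
loss.backward()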