Commit 5b0d0410 authored by IlyaOvodov's avatar IlyaOvodov

from_dict transform, CompositeLoss, DANN

parent b59c3474
from .params import AttrDict from .params import AttrDict
\ No newline at end of file import .params
import .pytorch
import .ignite
\ No newline at end of file
from .transforms import from_dict
\ No newline at end of file
def from_dict(key):
'''
Creates transform function extracting value from dict if input is dict
:param key: string
:return: functions: x -> x[key] if x is dict, x otherwise
'''
def call(data):
if isinstance(data, dict):
return data[key]
elif isinstance(data, tuple):
return tuple(call(di) for di in data)
else:
return data
return call
\ No newline at end of file
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import MeanLoss from .losses import CompositeLoss, MeanLoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
from .utils import set_reproducibility from .utils import set_reproducibility
from .utils import reproducibility_worker_init_fn from .utils import reproducibility_worker_init_fn
from .mean_loss import MeanLoss from .composite_loss import CompositeLoss
\ No newline at end of file from .mean_loss import MeanLoss
import torch
class CompositeLoss(torch.nn.modules.loss._Loss):
def __init__(self, loss_funcs):
super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs
self.loss_vals = [None] * (len(self.loss_funcs) + 1)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
self.loss_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs]
res = sum([w * self.loss_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)])
self.loss_vals.append(res)
return res
def __len__(self):
return len(self.loss_funcs)
def get_val(self, index):
def call(*kargs, **kwargs):
return self.loss_vals[index]
return call
...@@ -9,7 +9,7 @@ class MeanLoss(torch.nn.modules.loss._Loss): ...@@ -9,7 +9,7 @@ class MeanLoss(torch.nn.modules.loss._Loss):
def __init__(self, base_binary_locc): def __init__(self, base_binary_locc):
super(MeanLoss, self).__init__() super(MeanLoss, self).__init__()
self.base_loss = base_binary_locc self.base_loss = base_binary_locc
self.loss_vals = [] self.loss_vals = []
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape) assert y_true.shape == y_pred.shape, (y_pred.shape, y_true.shape)
...@@ -19,16 +19,16 @@ class MeanLoss(torch.nn.modules.loss._Loss): ...@@ -19,16 +19,16 @@ class MeanLoss(torch.nn.modules.loss._Loss):
return res return res
def __len__(self): def __len__(self):
''' '''
returns number of individual channel losses (not including the last value stored in self.loss_vals) returns number of individual channel losses (not including the last value stored in self.loss_vals)
''' '''
return len(self.loss_funcs) return len(self.loss_funcs)
def get_val(self, index): def get_val(self, index):
''' '''
returns function that returns individual channel loss cor channel index returns function that returns individual channel loss cor channel index
valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward() valid indexes are 0..self.len(). The last index=self.len() or index=1 is mean value returned by forward()
''' '''
def call(*kargs, **kwargs): def call(*kargs, **kwargs):
return self.loss_vals[index] return self.loss_vals[index]
return call return call
......
from .dann import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
\ No newline at end of file
'''
DANN module. See https://arxiv.org/abs/1505.07818, https://arxiv.org/abs/1409.7495
'''
import torch
import torch.nn as nn
from torch.autograd import Function
class ReverseLayerF(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha
return output, None
class DANN_module(nn.Module):
def __init__(self, gamma = 10., lambda_max = 1., **kwargs):
super(DANN_module, self).__init__()
self.gamma = gamma
self.lambda_max = lambda_max
self.progress = nn.Parameter(torch.Tensor([0,]), requires_grad = False) # must be updated from outside
def set_progress(self, value):
self.progress.data = torch.Tensor([value,]).to(self.progress.device)
def forward(self, input_data):
lambda_p = self.lambda_max * (2. / (1. + torch.exp(-self.gamma * self.progress.data)) - 1)
reverse_feature = ReverseLayerF.apply(input_data, lambda_p)
return reverse_feature
class Dann_Head(nn.Module):
def __init__(self, input_dims, num_classes, **kwargs):
super(Dann_Head, self).__init__()
self.input_dims = input_dims
self.pooling_depth = 100
self.dann_module = DANN_module(**kwargs)
self.pooling_modules = [nn.Sequential(
nn.Conv2d(dim, self.pooling_depth, 3),
nn.AdaptiveMaxPool2d(1)
)
for dim in input_dims]
for i, m in enumerate(self.pooling_modules):
self.add_module("pooling_module"+str(i), m)
self.domain_classifier = nn.Sequential()
self.domain_classifier.add_module('dann_fc1', nn.Linear(self.pooling_depth*len(input_dims), 100))
self.domain_classifier.add_module('dann_bn1', nn.BatchNorm1d(100)) # - dows not work for batch = 1
#self.domain_classifier.add_module('dann_bn1', nn.GroupNorm(5, 100)) # - dows not work for batch = 1
self.domain_classifier.add_module('dann_relu1', nn.ReLU(True))
self.domain_classifier.add_module('dann_fc2', nn.Linear(100, num_classes))
self.domain_classifier.add_module('dann_softmax', nn.LogSoftmax(dim=1))
self.loss = nn.NLLLoss()
def set_progress(self, value):
self.dann_module.set_progress(value)
def forward(self, inputs, y_true: torch.Tensor):
features = []
for i, x in enumerate(inputs[:len(self.input_dims)]):
assert x.shape[1] == self.input_dims[i]
x = self.dann_module(x)
x = self.pooling_modules[i](x)
features.append(x)
feature = torch.cat([f.flatten(start_dim=1) for f in features], dim=1)
domain_output = self.domain_classifier(feature)
loss = self.loss(domain_output, y_true.long())
return loss
class DannEncDecNet(nn.Module):
def __init__(self, base_net, input_dim, num_classes, **enc_dec_params):
super(DannEncDecNet, self).__init__()
self.net = base_net
self.dann_head = Dann_Head(input_dim, num_classes, **enc_dec_params)
self.bottleneck_data = None
self.dann_loss = None
def set_progress(self, value):
self.dann_head.set_progress(value)
def forward(self, x):
"""Sequentially pass `x` trough model`s `encoder` and `decoder` (return logits!)"""
x = self.net.encoder(x)
self.bottleneck_data = x
x = self.net.decoder(x)
return x
def calc_dann_loss(self, y_pred, y_true):
self.dann_loss = self.dann_head(self.bottleneck_data, y_true)
return self.dann_loss
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment