Commit 5b0d0410 authored by IlyaOvodov's avatar IlyaOvodov

from_dict transform, CompositeLoss, DANN

parent b59c3474
from .params import AttrDict from .params import AttrDict
import .params
import .pytorch
import .ignite
\ No newline at end of file
from .transforms import from_dict
\ No newline at end of file
def from_dict(key):
'''
Creates transform function extracting value from dict if input is dict
:param key: string
:return: functions: x -> x[key] if x is dict, x otherwise
'''
def call(data):
if isinstance(data, dict):
return data[key]
elif isinstance(data, tuple):
return tuple(call(di) for di in data)
else:
return data
return call
\ No newline at end of file
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import MeanLoss from .losses import CompositeLoss, MeanLoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
from .utils import set_reproducibility from .utils import set_reproducibility
from .utils import reproducibility_worker_init_fn from .utils import reproducibility_worker_init_fn
from .composite_loss import CompositeLoss
from .mean_loss import MeanLoss from .mean_loss import MeanLoss
import torch
class CompositeLoss(torch.nn.modules.loss._Loss):
def __init__(self, loss_funcs):
super(CompositeLoss, self).__init__()
self.loss_funcs = loss_funcs
self.loss_vals = [None] * (len(self.loss_funcs) + 1)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
self.loss_vals = [loss_fn(y_pred, y_true) for (loss_fn, _,) in self.loss_funcs]
res = sum([w * self.loss_vals[i] for i, (_, w,) in enumerate(self.loss_funcs)])
self.loss_vals.append(res)
return res
def __len__(self):
return len(self.loss_funcs)
def get_val(self, index):
def call(*kargs, **kwargs):
return self.loss_vals[index]
return call
from .dann import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
\ No newline at end of file
'''
DANN module. See https://arxiv.org/abs/1505.07818, https://arxiv.org/abs/1409.7495
'''
import torch
import torch.nn as nn
from torch.autograd import Function
class ReverseLayerF(Function):
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha
return output, None
class DANN_module(nn.Module):
def __init__(self, gamma = 10., lambda_max = 1., **kwargs):
super(DANN_module, self).__init__()
self.gamma = gamma
self.lambda_max = lambda_max
self.progress = nn.Parameter(torch.Tensor([0,]), requires_grad = False) # must be updated from outside
def set_progress(self, value):
self.progress.data = torch.Tensor([value,]).to(self.progress.device)
def forward(self, input_data):
lambda_p = self.lambda_max * (2. / (1. + torch.exp(-self.gamma * self.progress.data)) - 1)
reverse_feature = ReverseLayerF.apply(input_data, lambda_p)
return reverse_feature
class Dann_Head(nn.Module):
def __init__(self, input_dims, num_classes, **kwargs):
super(Dann_Head, self).__init__()
self.input_dims = input_dims
self.pooling_depth = 100
self.dann_module = DANN_module(**kwargs)
self.pooling_modules = [nn.Sequential(
nn.Conv2d(dim, self.pooling_depth, 3),
nn.AdaptiveMaxPool2d(1)
)
for dim in input_dims]
for i, m in enumerate(self.pooling_modules):
self.add_module("pooling_module"+str(i), m)
self.domain_classifier = nn.Sequential()
self.domain_classifier.add_module('dann_fc1', nn.Linear(self.pooling_depth*len(input_dims), 100))
self.domain_classifier.add_module('dann_bn1', nn.BatchNorm1d(100)) # - dows not work for batch = 1
#self.domain_classifier.add_module('dann_bn1', nn.GroupNorm(5, 100)) # - dows not work for batch = 1
self.domain_classifier.add_module('dann_relu1', nn.ReLU(True))
self.domain_classifier.add_module('dann_fc2', nn.Linear(100, num_classes))
self.domain_classifier.add_module('dann_softmax', nn.LogSoftmax(dim=1))
self.loss = nn.NLLLoss()
def set_progress(self, value):
self.dann_module.set_progress(value)
def forward(self, inputs, y_true: torch.Tensor):
features = []
for i, x in enumerate(inputs[:len(self.input_dims)]):
assert x.shape[1] == self.input_dims[i]
x = self.dann_module(x)
x = self.pooling_modules[i](x)
features.append(x)
feature = torch.cat([f.flatten(start_dim=1) for f in features], dim=1)
domain_output = self.domain_classifier(feature)
loss = self.loss(domain_output, y_true.long())
return loss
class DannEncDecNet(nn.Module):
def __init__(self, base_net, input_dim, num_classes, **enc_dec_params):
super(DannEncDecNet, self).__init__()
self.net = base_net
self.dann_head = Dann_Head(input_dim, num_classes, **enc_dec_params)
self.bottleneck_data = None
self.dann_loss = None
def set_progress(self, value):
self.dann_head.set_progress(value)
def forward(self, x):
"""Sequentially pass `x` trough model`s `encoder` and `decoder` (return logits!)"""
x = self.net.encoder(x)
self.bottleneck_data = x
x = self.net.decoder(x)
return x
def calc_dann_loss(self, y_pred, y_true):
self.dann_loss = self.dann_head(self.bottleneck_data, y_true)
return self.dann_loss
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment