Commit 7be4d213 authored by IlyaOvodov's avatar IlyaOvodov

label smoothing

parent 71888884
from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader from .data import CachedDataSet, BatchThreadingDataLoader, ThreadingDataLoader
from .losses import SimpleLoss, CompositeLoss, MeanLoss, LabelSmoothingBCEWithLogitsLoss from .losses import SimpleLoss, CompositeLoss, MeanLoss, LabelSmoothingBCEWithLogitsLoss, PseudoLabelingBCELoss
from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet from .modules import ReverseLayerF, DANN_module, Dann_Head, DannEncDecNet
......
from .composite_loss import SimpleLoss, CompositeLoss from .composite_loss import SimpleLoss, CompositeLoss
from .mean_loss import MeanLoss from .mean_loss import MeanLoss
from .label_smoothing import LabelSmoothingBCEWithLogitsLoss from .label_smoothing import LabelSmoothingBCEWithLogitsLoss
from .pseudo_labeling import PseudoLabelingBCELoss
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
def LabelSmoothingBCEWithLogitsLoss(label_smoothing = 0, **kwargs): def LabelSmoothingBCEWithLogitsLoss(label_smoothing = 0.1, **kwargs):
def loss(y, y_gt): def loss(y, y_gt):
''' '''
s = 1/(1+exp(-x)) s = 1/(1+exp(-x))
......
import torch
import math
class PseudoLabelingBCELoss(torch.nn.modules.loss._Loss):
'''
'''
def __init__(self, confindence_thr = 0.2, **kwargs):
super(PseudoLabelingBCELoss, self).__init__()
self.base_loss = torch.nn.BCEWithLogitsLoss(reduction = 'none')
self.logit_thr = math.log((1-confindence_thr)/confindence_thr)
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
pseudo_labels = (y_pred.sign()+1)/2
pseudo_labels_mask = (y_pred.abs()>self.logit_thr).float()
pseudo_labels_cnt = pseudo_labels_mask.sum()
self.val = (pseudo_labels_mask*self.base_loss(y_pred, pseudo_labels)).mean()
if pseudo_labels_cnt:
self.val *= torch.Tensor([y_pred.shape]).prod()/pseudo_labels_cnt
return self.val
def __len__(self):
'''
returns number of individual channel losses
'''
return 0
def get_val(self):
'''
returns function to get last result
'''
def call(*kargs, **kwargs):
return self.val
return call
...@@ -153,6 +153,6 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l ...@@ -153,6 +153,6 @@ def CreateCompositeLoss(loss_params: dict, eval_func=eval) -> torch.nn.modules.l
else: else:
loss_funcs = [] loss_funcs = []
for loss_param in loss_params: for loss_param in loss_params:
loss_i = CreateCompositeLoss(loss_param) loss_i = CreateCompositeLoss(loss_param, eval_func=eval_func)
loss_funcs.append((loss_i, loss_param.get('weight', 1.),)) loss_funcs.append((loss_i, loss_param.get('weight', 1.),))
return CompositeLoss(loss_funcs) return CompositeLoss(loss_funcs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment