Commit 21dc6361 authored by IlyaOvodov

FocalLoss

parent 68d5cadd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class DummyTimer:
@@ -152,3 +154,50 @@ def save_model(model, params, rel_dir, filename):
    dir_name = os.path.dirname(file_name)
    os.makedirs(dir_name, exist_ok=True)
    torch.save(model.state_dict(), file_name)
class FocalBceLoss(nn.Module):
    """Binary focal loss: weight * (1 - pt)**gamma * BCE."""
    def __init__(self, weight=1, gamma=2, logits=False, reduce=True):
        super(FocalBceLoss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.logits = logits  # True if inputs are raw scores, False if probabilities
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            bce_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)  # probability the model assigns to the true label
        f_loss = self.weight * (1 - pt) ** self.gamma * bce_loss
        if self.reduce:
            return torch.mean(f_loss)
        else:
            return f_loss
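
# Hypothetical usage sketch (not part of this commit): multi-label focal BCE
# over raw scores; the shapes and values below are illustrative assumptions.
#     criterion = FocalBceLoss(weight=1, gamma=2, logits=True, reduce=True)
#     preds = torch.randn(4, 3, requires_grad=True)   # raw scores, 4 samples x 3 labels
#     labels = torch.randint(0, 2, (4, 3)).float()    # 0/1 targets
#     loss = criterion(preds, labels)                  # scalar mean focal loss
#     loss.backward()
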
class FocalCeLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights."""
    def __init__(self, weight=None, gamma=2, logits=False, reduce=True):
        super(FocalCeLoss, self).__init__()
        self.weight = weight  # optional 1-D tensor of per-class weights
        self.gamma = gamma
        self.logits = logits  # True if inputs are raw scores, False if log-probabilities
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        else:
            ce_loss = F.nll_loss(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the target class
        f_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            # per-sample weights looked up by target class
            weights = torch.index_select(self.weight, 0, targets.view(-1)).view(targets.shape)
            f_loss *= weights
        if self.reduce:
            return torch.mean(f_loss)
        else:
            return f_loss
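
Not part of the commit above, but a minimal usage sketch under the assumption that FocalCeLoss receives raw multi-class logits (logits=True) and a 1-D tensor of per-class weights; the batch size, class count, and weight values are illustrative:

# assumed per-class weights for a 5-class problem
class_weights = torch.tensor([1.0, 2.0, 1.0, 0.5, 1.0])
criterion = FocalCeLoss(weight=class_weights, gamma=2, logits=True, reduce=True)
logits = torch.randn(8, 5, requires_grad=True)   # raw scores: 8 samples, 5 classes
targets = torch.randint(0, 5, (8,))              # integer class labels
loss = criterion(logits, targets)                # scalar mean focal loss
loss.backward()                                  # gradients flow back to the logits

With gamma=2, well-classified samples (pt close to 1) contribute almost nothing to the loss, which is what lets focal loss concentrate training on hard examples.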