# -*- coding: utf-8 -*-
# file: FocalLoss.py
# time: 14:16 2022/12/23
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2021. All Rights Reserved.
from typing import List, Optional
import torch
import torch.nn.functional as F
class FocalLoss(torch.nn.Module):
    # borrowed from Flair
    """
    Focal Loss (https://arxiv.org/pdf/1708.02002.pdf)
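
    FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t),
    where the (1 - p_t) ** gamma factor down-weights well-classified
    examples and alpha_t is an optional per-class weight.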

    Shape:
        - input: (N, C)
        - target: (N)
        - Output: scalar loss

    Examples:
        >>> loss = FocalLoss(gamma=2, alpha=[1.0] * 7)
        >>> input = torch.randn(3, 7, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(7)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, gamma=0, alpha: Optional[List[float]] = None, reduction="none"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if alpha is not None:
            self.alpha = torch.tensor(alpha, dtype=torch.float)
        self.reduction = reduction

    def forward(self, input, target):
        # target: [N] -> [N, 1] so it can index the class dimension
        target = target.unsqueeze(-1)
        # pt / logpt: [N, C] class probabilities and log-probabilities
        pt = F.softmax(input, dim=-1)
        logpt = F.log_softmax(input, dim=-1)
        # pick out the true-class entries -> [N]
        pt = pt.gather(1, target).squeeze(-1)
        logpt = logpt.gather(1, target).squeeze(-1)
        if self.alpha is not None:
            # [N] at[i] = alpha[target[i]]; move alpha to the input's
            # device so the loss also works for GPU tensors
            at = self.alpha.to(input.device).gather(0, target.squeeze(-1))
            logpt = logpt * at
        # FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t)
        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return loss.mean()
        return loss.sum()

    @staticmethod
    def convert_binary_pred_to_two_dimension(x, is_logits=True):
"""
Args:
x: (*): (log) prob of some instance has label 1
is_logits: if True, x represents log prob; otherwhise presents prob
Returns:
y: (*, 2), where y[*, 1] == log prob of some instance has label 0,
y[*, 0] = log prob of some instance has label 1
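
        Example (illustrative values, not from the original docstring):
            >>> logits = torch.tensor([0.5, -1.2])
            >>> FocalLoss.convert_binary_pred_to_two_dimension(logits).shape
            torch.Size([2, 2])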
"""
probs = torch.sigmoid(x) if is_logits else x
probs = probs.unsqueeze(-1)
probs = torch.cat([1 - probs, probs], dim=-1)
logprob = torch.log(probs + 1e-4) # 1e-4 to prevent being rounded to 0 in fp16
return logprob

    def __str__(self):
        return f"Focal Loss gamma:{self.gamma}"

    def __repr__(self):
        return str(self)
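

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # Multi-class case: 3 samples, 7 classes, with per-class weights alpha.
    torch.manual_seed(0)
    loss_fn = FocalLoss(gamma=2, alpha=[1.0] * 7, reduction="mean")
    logits = torch.randn(3, 7, requires_grad=True)
    labels = torch.empty(3, dtype=torch.long).random_(7)
    loss = loss_fn(logits, labels)
    loss.backward()
    print(loss_fn, "->", float(loss))

    # Binary case: convert (N,) logits to (N, 2) log-probabilities first.
    # forward() applies softmax internally, and softmax over log-probabilities
    # recovers (approximately) the probabilities, so the two compose correctly.
    binary_logits = torch.randn(5)
    two_dim = FocalLoss.convert_binary_pred_to_two_dimension(binary_logits)
    binary_labels = torch.empty(5, dtype=torch.long).random_(2)
    print(FocalLoss(gamma=2, reduction="mean")(two_dim, binary_labels))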