pyabsa.networks.losses.FocalLoss

Classes

FocalLoss

Focal loss (https://arxiv.org/pdf/1708.02002.pdf)

Module Contents

class pyabsa.networks.losses.FocalLoss.FocalLoss(gamma=0, alpha: List[float] = None, reduction='none')

Bases: torch.nn.Module

Focal loss (https://arxiv.org/pdf/1708.02002.pdf).

Shape:

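  • input: (N, C) class scores for C classes

  • target: (N) integer class indices in [0, C)

  • Output: scalar loss; with reduction='none' (the default), the usual PyTorch convention would instead give per-sample losses of shape (N)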
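For reference, the linked paper (Lin et al., "Focal Loss for Dense Object Detection") defines the loss below. Here p_t is the predicted probability of the target class (assumed here to be the softmax of input at index target), α_t is the weight of the target class (assumed to correspond to alpha[target]), and γ is gamma:

```latex
\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With γ = 0 and every α_t = 1 this reduces to ordinary cross-entropy, so the defaults gamma=0 and alpha=None should amount to plain cross-entropy. Larger γ shrinks the contribution of well-classified examples (p_t close to 1).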

Examples

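>>> import torch
>>> from pyabsa.networks.losses.FocalLoss import FocalLoss
>>> loss = FocalLoss(gamma=2, alpha=[1.0]*7)
>>> input = torch.randn(3, 7, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(7)
>>> output = loss(input, target)
>>> output.mean().backward()  # reduce to a scalar first; reduction='none' does not reduce

gamma = 0
alpha = None
reduction = 'none'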
forward(input, target)
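forward is listed without a description. The following is a minimal sketch of the computation it is expected to perform, assuming that input holds unnormalized scores turned into probabilities with a softmax, that alpha[c] weights class c, and that reduction follows the usual PyTorch values. focal_loss_sketch is a hypothetical standalone helper for illustration, not part of pyabsa, and its 'mean' and 'sum' branches are assumptions.

```python
from typing import List, Optional

import torch
import torch.nn.functional as F


def focal_loss_sketch(
    logits: torch.Tensor,                 # (N, C) unnormalized class scores
    target: torch.Tensor,                 # (N,) integer class indices
    gamma: float = 0.0,
    alpha: Optional[List[float]] = None,  # hypothetical: per-class weights, length C
    reduction: str = "none",              # assumed values: "none", "mean", "sum"
) -> torch.Tensor:
    # log p for every class, then keep log p_t of each sample's target class: (N,)
    log_probs = F.log_softmax(logits, dim=-1)
    log_pt = log_probs.gather(1, target.unsqueeze(-1)).squeeze(-1)
    # p_t is taken before any alpha weighting, so the modulating factor is unweighted
    pt = log_pt.exp()
    if alpha is not None:
        # alpha_t[i] = alpha[target[i]]
        alpha_t = torch.tensor(alpha, dtype=logits.dtype, device=logits.device)[target]
        log_pt = log_pt * alpha_t
    # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
    loss = -((1 - pt) ** gamma) * log_pt
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # "none": per-sample losses, shape (N,)
```

Setting gamma=0 and alpha=None should make the sketch agree with torch.nn.functional.cross_entropy, which is a quick way to check any implementation of this loss.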
static convert_binary_pred_to_two_dimension(x, is_logits=True)
Parameters:
  • x – tensor of shape (*): the (log) probability that an instance has label 1

  • is_logits – if True, x holds log probabilities; otherwise, x holds probabilities

Returns:

y, a tensor of shape (*, 2), where y[*, 1] is the log probability that the instance has label 0 and y[*, 0] is the log probability that it has label 1.

__str__()
__repr__()