pyabsa.networks.losses.FocalLoss

Module Contents

Classes

FocalLoss

Focal loss (https://arxiv.org/pdf/1708.02002.pdf)

class pyabsa.networks.losses.FocalLoss.FocalLoss(gamma=0, alpha: List[float] = None, reduction='none')[source]

Bases: torch.nn.Module

Focal loss (https://arxiv.org/pdf/1708.02002.pdf).

Shape:

  • input: (N, C)

  • target: (N)

  • Output: Scalar loss
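To make the shapes above concrete, here is a minimal plain-Python sketch of the per-sample formula from the linked paper, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), applied to one row of an (N, C) input. The helper name focal_loss_single is hypothetical, and the sketch illustrates the formula only; it is not the library's implementation and does not cover the reduction argument.

```python
import math

def focal_loss_single(logits, target, gamma=0.0, alpha=None):
    """Focal loss for one sample: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    # Numerically stable softmax over the C class scores.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    p_t = exps[target] / sum(exps)  # probability assigned to the true class
    # alpha holds one weight per class; None means no class weighting (assumption of this sketch).
    alpha_t = 1.0 if alpha is None else alpha[target]
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

logits = [2.0, 0.5, -1.0]  # one row of an (N, C) input with C = 3
target = 0                 # the matching entry of the (N) target

ce = focal_loss_single(logits, target, gamma=0.0)  # gamma = 0 reduces to cross-entropy
fl = focal_loss_single(logits, target, gamma=2.0)  # gamma = 2 down-weights this easy sample
```

With gamma = 0 the modulating factor is 1, so the value equals ordinary cross-entropy; a larger gamma shrinks the loss of samples whose true class already has high probability.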

Examples

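>>> import torch
>>> from pyabsa.networks.losses.FocalLoss import FocalLoss
>>> loss = FocalLoss(gamma=2, alpha=[1.0]*7)
>>> input = torch.randn(3, 7, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(7)
>>> output = loss(input, target)
>>> # sum() keeps backward() valid if the default reduction='none' yields one loss per sample
>>> output.sum().backward()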

forward(input, target)[source]

static convert_binary_pred_to_two_dimension(x, is_logits=True)[source]

Parameters:

  • x – (*): (log) prob that an instance has label 1

  • is_logits – if True, x represents log prob; otherwise it represents prob

Returns:

y – (*, 2), where y[*, 1] = log prob that an instance has label 0 and y[*, 0] = log prob that an instance has label 1

__str__()[source]

Return str(self).

__repr__()[source]

Return repr(self).