pyabsa.networks.losses.FocalLoss
Module Contents
Classes
FocalLoss | Focal loss (https://arxiv.org/pdf/1708.02002.pdf)
- class pyabsa.networks.losses.FocalLoss.FocalLoss(gamma=0, alpha: List[float] = None, reduction='none')
Bases: torch.nn.Module
Focal loss (https://arxiv.org/pdf/1708.02002.pdf).
Shape:
- input: (N, C), scores for N samples over C classes
- target: (N), integer class indices in [0, C)
- output: scalar loss
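For reference, the cited paper defines the focal loss as a cross-entropy term scaled by a modulating factor (1 - p_t)^γ. The form below writes it for the (N, C) shapes above, assuming a softmax over the C classes (the paper's main derivation is binary) and assuming that the `gamma` and `alpha` arguments play the roles of γ and α; t_i denotes the target class of sample i. With γ = 0 and α_c = 1 for every class, FL_i reduces to the standard cross-entropy -log p_i, and a batch-level scalar is obtained by averaging or summing FL_i over i.

```latex
p_i = \frac{\exp(x_{i,t_i})}{\sum_{c=1}^{C} \exp(x_{i,c})}, \qquad
\mathrm{FL}_i = -\,\alpha_{t_i}\,(1 - p_i)^{\gamma}\,\log p_i, \qquad i = 1, \dots, N
```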
Examples
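>>> import torch
>>> from pyabsa.networks.losses.FocalLoss import FocalLoss
>>> loss = FocalLoss(gamma=2, alpha=[1.0]*7)
>>> input = torch.randn(3, 7, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(7)
>>> output = loss(input, target)
>>> output.mean().backward()  # reduce to a scalar first; with reduction='none' the loss may be per-sample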
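A minimal, dependency-free sketch of the computation can help when checking numbers by hand. `focal_loss_reference` is a hypothetical helper, not part of PyABSA; it assumes that `input` holds unnormalized logits turned into probabilities with a softmax over the C classes, that `alpha[c]` weights samples whose target is class c, and that `reduction` accepts `'none'`, `'mean'` and `'sum'`. The library's actual implementation may differ in these details.

```python
import math
from typing import List, Optional, Sequence, Union


def focal_loss_reference(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    gamma: float = 0.0,
    alpha: Optional[List[float]] = None,
    reduction: str = "none",
) -> Union[List[float], float]:
    """Hypothetical pure-Python focal loss (not part of PyABSA).

    logits: N rows of C unnormalized class scores; targets: N class indices.
    """
    losses = []
    for row, t in zip(logits, targets):
        # Stable log-softmax: log_norm = log(sum(exp(row))) computed around the max.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(v - m) for v in row))
        log_pt = row[t] - log_norm  # log-probability of the target class
        pt = math.exp(log_pt)
        weight = alpha[t] if alpha is not None else 1.0  # assumed per-class weight
        # FL_i = -alpha_t * (1 - p_t)^gamma * log p_t
        losses.append(-weight * (1.0 - pt) ** gamma * log_pt)
    if reduction == "none":
        return losses
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Usage: two samples, three classes, focusing parameter gamma = 2.
print(focal_loss_reference([[1.0, 2.0, 0.5], [0.3, -1.0, 2.2]], [1, 2], gamma=2.0))
```

For uniform logits `[[0.0, 0.0]]` with target `[0]`, p_t = 0.5, so the sketch returns ln 2 ≈ 0.693 with `gamma=0` and 0.25 · ln 2 ≈ 0.173 with `gamma=2`: the modulating factor shrinks the loss of samples the model already classifies with moderate confidence.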
- static convert_binary_pred_to_two_dimension(x, is_logits=True)
- Parameters:
x – (*): (log) probability that an instance has label 1
is_logits – if True, x represents a log probability; otherwise it represents a probability
- Returns:
y – (*, 2), where y[*, 1] is the log probability that an instance has label 0 and y[*, 0] is the log probability that it has label 1