@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from typing import Optional
 
 from ..utils import _log_api_usage_once
 
@@ -56,3 +57,95 @@ def sigmoid_focal_loss(
             f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum'"
         )
     return loss
+
+
+def softmax_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    gamma: float = 2,
+    eps: float = 1e-6,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Multi-class (softmax) variant of the focal loss used in RetinaNet for
+    dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs (Tensor): A float tensor of arbitrary shape.
+            The predictions for each example. Softmax is applied on this
+            tensor to convert the raw logits to class probabilities.
+            Expected shape is (N, C, *).
+        targets (Tensor): A long tensor of class indices, like the one
+            expected by PyTorch's CrossEntropyLoss:
+            https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+            The class dimension is expected to be absent, and each element
+            is the class index in the range [0, C). Expected shape is (N, *).
+        alpha (Tensor[float]): Per-class weighting tensor with C elements
+            used to balance positive vs negative examples, or ``None`` for
+            no weighting. The elements of alpha should sum to 1.0.
+            Default: ``None``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples. Default: ``2``.
+        eps (float): Tolerance used when checking that the elements of
+            alpha sum to 1.0. Default: ``1e-6``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` | ``'instance-sum-batch-mean'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed.
+            ``'instance-sum-batch-mean'``: The output will be summed for each
+                instance in the batch, and then averaged across the entire
+                batch. Default: ``'none'``.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    # Adapted from this version by Thomas V.:
+    # https://discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289/2
+    # Referenced from this GitHub issue:
+    # https://github.com/pytorch/vision/issues/3250
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(softmax_focal_loss)
+
+    assert targets.dtype == torch.long, f"Expected a long tensor for 'targets', but got {targets.dtype}"
+
+    logits = inputs
+    weight = None
+    if alpha is not None:
+        num_classes = logits.size(1)
+        assert isinstance(alpha, torch.Tensor), f"Expected alpha to be a torch.Tensor, got {type(alpha)}"
+        assert alpha.size(0) == num_classes, (
+            f"Expected alpha (weights) to have {num_classes} elements, but got {alpha.size(0)} elements"
+        )
+        assert abs(alpha.sum() - 1.0) <= eps, (
+            f"Expected elements of alpha to sum to 1.0, instead they sum to {alpha.sum()}"
+        )
+        weight = alpha
+
+    # The unweighted CE loss is used to recover pt below; the alpha-weighted
+    # CE loss carries the class weighting into the final loss term.
+    ce_loss = F.cross_entropy(logits, targets, reduction="none")
+    weighted_ce_loss = F.cross_entropy(logits, targets, weight=weight, reduction="none")
+    # Instead of computing inputs.softmax(dim=1), we recover pt from the
+    # exponentiated negative log of the unweighted cross entropy loss.
+    #
+    # Why does this work?
+    # Since this is a multi-class setting, only one class is active for
+    # each element: the target probability of that class is 1, and the
+    # rest are all 0.
+    #
+    # Cross entropy loss computes:
+    #     pt = softmax(...)[target]
+    #     loss = -1.0 * log(pt)
+    #
+    # Hence, exp(-loss) == pt.
+    #
+    # Note that the identity only holds for the unweighted loss: the
+    # weighted loss is -alpha_t * log(pt), so exponentiating it would give
+    # pt ** alpha_t rather than pt. pt is therefore recovered from ce_loss,
+    # while the alpha weighting enters through weighted_ce_loss below.
+    #
+    # This method works only if targets is a long tensor, hence the
+    # assertion earlier.
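+    # A tiny worked check of the identity, with illustrative numbers for a
+    # single example with 2 classes and target class 0:
+    #
+    #     logits = [2.0, 0.0]  ->  softmax = [0.8808, 0.1192]
+    #     ce_loss = -log(0.8808) = 0.1269
+    #     exp(-0.1269) = 0.8808 == pt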
+    pt = torch.exp(-ce_loss)
+    focal_loss = ((1 - pt) ** gamma) * weighted_ce_loss
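+    # For reference, the direct computation that the exp(-ce_loss) trick
+    # avoids (a sketch, assuming targets has shape (N, *)):
+    #
+    #     p = inputs.softmax(dim=1)
+    #     pt = p.gather(1, targets.unsqueeze(1)).squeeze(1)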
+    if reduction == "none":
+        return focal_loss
+    elif reduction == "sum":
+        return focal_loss.sum()
+    elif reduction == "mean":
+        return focal_loss.mean()
+    elif reduction == "instance-sum-batch-mean":
+        # Sum over all elements of each instance, then average over the
+        # batch dimension N.
+        return focal_loss.sum() / logits.size(0)
+    else:
+        raise ValueError(
+            f"Invalid value for arg 'reduction': '{reduction}'.\n"
+            f"Supported reduction modes: 'none', 'mean', 'sum', 'instance-sum-batch-mean'"
+        )
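+
+
+# A minimal usage sketch (illustrative shapes and values, not part of the
+# library API):
+#
+#     logits = torch.randn(8, 5)           # (N, C) raw class scores
+#     targets = torch.randint(0, 5, (8,))  # class indices in [0, C)
+#     alpha = torch.full((5,), 1.0 / 5)    # per-class weights summing to 1.0
+#     loss = softmax_focal_loss(logits, targets, alpha=alpha, reduction="mean")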