@@ -56,3 +56,81 @@ def sigmoid_focal_loss(
             f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
         )
     return loss
+
+def softmax_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs (Tensor): A float tensor of arbitrary shape.
+            The predictions for each example. Softmax() is applied on this tensor
+            to convert the raw logits to class probabilities. Expected shape is
+            (N, C, *).
+        targets (Tensor): Must be a long tensor similar to the one expected by
+            PyTorch's CrossEntropyLoss.
+            https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+            The class dimension is expected to be absent, and each
+            element is the class value in the range [0, C).
+        alpha (float): Weighting factor in range (0, 1) to balance
+            positive vs negative examples, or -1 for ignore. Default: ``0.25``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples. Default: ``2``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` | ``'instance-sum-batch-mean'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed.
+            ``'instance-sum-batch-mean'``: The output will be summed for each
+                instance in the batch, and then averaged across the entire
+                batch. Default: ``'none'``.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    # Adapted from this version by Thomas V.
+    # https://discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289/2
+    # Referenced from this github issue:
+    # https://github.com/pytorch/vision/issues/3250
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(softmax_focal_loss)
+
+    assert targets.dtype == torch.long, f"Expected a long tensor for 'targets', but got {targets.dtype}"
+
+    logits = inputs
+    ce_loss = nn.functional.cross_entropy(logits, targets, reduction="none")
+    # Instead of computing inputs.softmax(dim=1), we use the exponentiated
+    # negative log of the cross entropy loss.
+    #
+    # Why does this work?
+    # Since this is a multi-class setting, only one class is active. The
+    # target probability of that class is 1, and the rest are all 0.
+    #
+    # Cross Entropy Loss computes:
+    #   pt = softmax(...)
+    #   loss = -1.0 * log(pt)
+    #
+    # Hence, exp(-loss) == pt
+    #
+    # This trick works only if targets is a long tensor. If it's a float
+    # tensor, then each value is a probability, and we'd need to divide
+    # the result of cross entropy loss by that probability, and hence would
+    # need to compute the softmax manually anyway. We don't implement that
+    # here for brevity, but this code can be extended for such a use-case.
+    pt = torch.exp(-ce_loss)
+    focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
+    if reduction == "none":
+        return focal_loss
+    elif reduction == "sum":
+        return focal_loss.sum()
+    elif reduction == "mean":
+        return focal_loss.mean()
+    elif reduction == "instance-sum-batch-mean":
+        return focal_loss.sum() / logits.size(0)
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum', 'instance-sum-batch-mean'"
+        )
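The comment block above leans on the identity exp(-cross_entropy) == p_t for long (class-index) targets. Below is a minimal standalone sketch of that check using only plain PyTorch; the tensor shapes are illustrative and not part of the patch.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: 4 samples, 10 classes.
logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))  # long tensor of class indices

# Per-sample cross entropy: -log(softmax(logits)[i, targets[i]])
ce = F.cross_entropy(logits, targets, reduction="none")
pt_from_ce = torch.exp(-ce)

# The same target-class probability computed explicitly via softmax + gather.
pt_direct = logits.softmax(dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)

assert torch.allclose(pt_from_ce, pt_direct, atol=1e-6)
```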
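For reference, a usage sketch on dense (N, C, H, W) predictions, where ``'instance-sum-batch-mean'`` differs from ``'mean'``. The import path below is an assumption about how the patch would be exposed; the diff itself only adds the function to the focal loss module.

```python
import torch

# Assumed import path; this diff does not show how (or whether) the new
# function is re-exported from torchvision.ops.
from torchvision.ops.focal_loss import softmax_focal_loss

logits = torch.randn(2, 5, 8, 8, requires_grad=True)  # (N, C, H, W) raw scores
targets = torch.randint(0, 5, (2, 8, 8))              # (N, H, W) long class indices

# Per-element loss, shape (N, H, W).
per_elem = softmax_focal_loss(logits, targets, alpha=0.25, gamma=2, reduction="none")

# Sum over all elements, divided by the batch size N.
loss = softmax_focal_loss(logits, targets, reduction="instance-sum-batch-mean")
loss.backward()
```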