
Commit 4d41785

Add softmax_focal_loss() to allow multi-class focal loss
In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, sigmoid_focal_loss() isn't useful in such cases. I found that others have been asking for this as well in pytorch#3250, so I decided to submit a PR for it.
1 parent 8324c48 commit 4d41785
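
For context, a minimal sketch (illustrative only, not part of the commit) of the segmentation use case described above, assuming the usual (N, C, H, W) logits with per-pixel integer class targets, and assuming the softmax_focal_loss from the diff below is in scope:

import torch

# Hypothetical batch: 2 images, 5 classes, 4x4 pixels.
logits = torch.randn(2, 5, 4, 4)           # raw per-class scores per pixel
targets = torch.randint(0, 5, (2, 4, 4))   # one class index per pixel

# sigmoid_focal_loss() expects float targets with the same shape as the
# inputs (one binary label per logit), so it doesn't fit this multi-class
# setup. The softmax variant consumes class indices directly, like
# torch.nn.functional.cross_entropy:
loss = softmax_focal_loss(logits, targets, reduction="mean")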

File tree: 1 file changed, +78 −0 lines changed


torchvision/ops/focal_loss.py

Lines changed: 78 additions & 0 deletions
@@ -56,3 +56,81 @@ def sigmoid_focal_loss(
             f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
         )
     return loss
+
+def softmax_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: float = 0.25,
+    gamma: float = 2,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs (Tensor): A float tensor of arbitrary shape.
+            The predictions for each example. Softmax() is applied on this tensor
+            to convert the raw logits to class probabilities. Expected shape is
+            (N, C, *).
+        targets (Tensor): Must be a long tensor similar to the one expected by
+            PyTorch's CrossEntropyLoss:
+            https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+            The class dimension is expected to be absent, and each
+            element is a class value in the range [0, C).
+        alpha (float): Weighting factor in range (0, 1) to balance
+            positive vs negative examples, or -1 to ignore. Default: ``0.25``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples. Default: ``2``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` | ``'instance-sum-batch-mean'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed.
+            ``'instance-sum-batch-mean'``: The output will be summed for each
+                instance in the batch, and then averaged across the entire
+                batch. Default: ``'none'``.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    # Adapted from this version by Thomas V.:
+    # https://discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289/2
+    # Referenced from this GitHub issue:
+    # https://github.com/pytorch/vision/issues/3250
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(softmax_focal_loss)
+
+    assert targets.dtype == torch.long, f"Expected a long tensor for 'targets', but got {targets.dtype}"
+
+    logits = inputs
+    ce_loss = nn.functional.cross_entropy(logits, targets, reduction="none")
+    # Instead of computing inputs.softmax(dim=1), we use the exponentiated
+    # negative of the cross entropy loss.
+    #
+    # Why does this work?
+    # Since this is a multi-class setting, only one class is active. The
+    # target probability of that class is 1, and the rest are all 0.
+    #
+    # Cross entropy loss computes:
+    #     pt = softmax(...)
+    #     loss = -1.0 * log(pt)
+    #
+    # Hence, exp(-loss) == pt.
+    #
+    # This trick works only if targets is a long tensor. If it were a float
+    # tensor, each value would be a probability, and we'd need to divide
+    # the result of cross entropy loss by that probability, and hence would
+    # need to compute the softmax manually anyway. We don't implement that
+    # here for brevity, but this code can be extended for such a use case.
+    pt = torch.exp(-ce_loss)
+    focal_loss = alpha * ((1 - pt) ** gamma) * ce_loss
+    if reduction == "none":
+        return focal_loss
+    elif reduction == "sum":
+        return focal_loss.sum()
+    elif reduction == "mean":
+        return focal_loss.mean()
+    elif reduction == "instance-sum-batch-mean":
+        return focal_loss.sum() / logits.size(0)
+    else:
+        raise ValueError(
+            f"Invalid value for arg 'reduction': '{reduction}'\n Supported reduction modes: 'none', 'mean', 'sum', 'instance-sum-batch-mean'"
+        )
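
As a quick sanity check on the exp(-ce_loss) == pt identity used in the comments above, the following sketch (mine, not part of the commit) compares the probabilities recovered from the per-sample cross entropy loss against an explicit softmax:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 4)          # 3 samples, 4 classes
targets = torch.tensor([0, 2, 1])   # one target class index per sample

# Per-sample cross entropy: -log(softmax(logits)[i, targets[i]])
ce_loss = F.cross_entropy(logits, targets, reduction="none")

# Exponentiating the negated loss recovers the predicted probability of
# the target class without materializing the full softmax tensor.
pt_from_loss = torch.exp(-ce_loss)
pt_explicit = logits.softmax(dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)

assert torch.allclose(pt_from_loss, pt_explicit)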
