Commit 83c100c

Add softmax_focal_loss() to allow multi-class focal loss
In image segmentation tasks, focal loss is useful when classifying an image pixel as one of N classes. Unfortunately, `sigmoid_focal_loss()` doesn't cover this multi-class case. I found that others have been asking for this as well in pytorch#3250, so I decided to submit a PR for it.
1 parent 8324c48 commit 83c100c

torchvision/ops/focal_loss.py

Lines changed: 93 additions & 0 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from typing import Optional
 
 from ..utils import _log_api_usage_once
 
@@ -56,3 +57,95 @@ def sigmoid_focal_loss(
             f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
         )
     return loss
+
+
+def softmax_focal_loss(
+    inputs: torch.Tensor,
+    targets: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    gamma: float = 2,
+    eps: float = 1e-6,
+    reduction: str = "none",
+) -> torch.Tensor:
+    """
+    Multi-class (softmax) variant of the focal loss used in RetinaNet for
+    dense detection: https://arxiv.org/abs/1708.02002.
+
+    Args:
+        inputs (Tensor): A float tensor of arbitrary shape.
+            The predictions for each example. Softmax() is applied on this
+            tensor to convert the raw logits to class probabilities.
+            Expected shape is (N, C, *).
+        targets (Tensor): Must be a long tensor similar to the one expected by
+            PyTorch's CrossEntropyLoss.
+            https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
+            The class dimension is expected to be absent, and each element is
+            the class value in the range [0, C).
+        alpha (Tensor[float]): Weighting factor in range (0, 1) to balance
+            positive vs negative examples, or None for no weighting. The
+            elements of alpha should sum to 1.0. Default: ``None``.
+        gamma (float): Exponent of the modulating factor (1 - p_t) to
+            balance easy vs hard examples. Default: ``2``.
+        eps (float): Tolerance used when checking that the elements of alpha
+            sum to 1.0. Default: ``1e-6``.
+        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` | ``'instance-sum-batch-mean'``
+            ``'none'``: No reduction will be applied to the output.
+            ``'mean'``: The output will be averaged.
+            ``'sum'``: The output will be summed.
+            ``'instance-sum-batch-mean'``: The output will be summed for each
+                instance in the batch, and then averaged across the entire
+                batch. Default: ``'none'``.
+    Returns:
+        Loss tensor with the reduction option applied.
+    """
+    # Adapted from this version by Thomas V.
+    # https://discuss.pytorch.org/t/focal-loss-for-imbalanced-multi-class-classification-in-pytorch/61289/2
+    # Referenced from this github issue:
+    # https://github.com/pytorch/vision/issues/3250
+    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        _log_api_usage_once(softmax_focal_loss)
+
+    assert targets.dtype == torch.long, f"Expected a long tensor for 'targets', but got {targets.dtype}"
+
+    logits = inputs
+    weight = None
+    if alpha is not None:
+        num_classes = logits.size(1)
+        assert isinstance(alpha, torch.Tensor), f"Expected alpha to be a torch.Tensor, got {type(alpha)}"
+        assert (
+            alpha.size(0) == num_classes
+        ), f"Expected alpha (weights) to have {num_classes} elements, but got {alpha.size(0)} elements"
+        assert abs(alpha.sum() - 1.0) <= eps, f"Expected the elements of alpha to sum to 1.0, instead they sum to {alpha.sum()}"
+        weight = alpha
+
+    # The unweighted cross entropy is used to recover p_t; the alpha-weighted
+    # version is the term the focal modulating factor scales.
+    ce_loss = F.cross_entropy(logits, targets, reduction="none")
+    weighted_ce_loss = F.cross_entropy(logits, targets, weight=weight, reduction="none")
+    # Instead of computing inputs.softmax(dim=1), we use the exponentiated
+    # negative of the (unweighted) cross entropy loss.
+    #
+    # Why does this work?
+    # Since this is a multi-class setting, only one class is active. The
+    # target probability of that class is 1, and the rest are all 0.
+    #
+    # Cross Entropy Loss computes:
+    #     pt = softmax(...)
+    #     loss = -1.0 * log(pt)
+    #
+    # Hence, exp(-loss) == pt.
+    #
+    # This holds only for the unweighted loss and only if targets is a long
+    # tensor, hence the assertion earlier.
+    pt = torch.exp(-ce_loss)
+    focal_loss = ((1 - pt) ** gamma) * weighted_ce_loss
+    if reduction == "none":
+        return focal_loss
+    elif reduction == "sum":
+        return focal_loss.sum()
+    elif reduction == "mean":
+        return focal_loss.mean()
+    elif reduction == "instance-sum-batch-mean":
+        return focal_loss.sum() / logits.size(0)
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction}' \n Supported reduction modes: 'none', 'mean', 'sum', 'instance-sum-batch-mean'"
+        )
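A quick sanity check of the exp(-loss) == pt identity that the comments in the diff rely on. This is a standalone sketch, not part of the commit:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)            # (N, C) raw scores
    targets = torch.randint(0, 3, (4,))   # long tensor of class indices

    # Per-example -log(p_t) from unweighted cross entropy ...
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt_from_ce = torch.exp(-ce)
    # ... matches the softmax probability of the target class.
    pt_direct = logits.softmax(dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    assert torch.allclose(pt_from_ce, pt_direct)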

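For reference, a hypothetical end-to-end call on a segmentation-style input. The import path follows this commit's file location; whether the function ships in a torchvision release depends on the PR being merged:

    import torch
    from torchvision.ops.focal_loss import softmax_focal_loss  # path per this commit

    N, C, H, W = 2, 5, 8, 8
    logits = torch.randn(N, C, H, W)          # (N, C, *) raw logits
    targets = torch.randint(0, C, (N, H, W))  # long tensor, class dim absent
    alpha = torch.full((C,), 1.0 / C)         # per-class weights summing to 1.0

    loss = softmax_focal_loss(logits, targets, alpha=alpha, gamma=2.0,
                              reduction="instance-sum-batch-mean")
    print(loss.item())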