Skip to content

Commit 83d2318

Browse files
committed
Generalize AsymmetricUnifiedFocalLoss for align interface
Signed-off-by: ytl0623 <[email protected]>
1 parent 0a29b5e commit 83d2318

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -163,16 +163,16 @@ def __init__(
163163
delta: float = 0.7,
164164
reduction: LossReduction | str = LossReduction.MEAN,
165165
include_background: bool = True,
166-
use_softmax: bool = False
166+
use_softmax: bool = False,
167167
):
168168
"""
169169
Args:
170170
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
171171
num_classes : number of classes, it only supports 2 now. Defaults to 2.
172-
weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
173-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
174172
delta : weight of the background. Defaults to 0.7.
175-
reduction : reduction mode for the loss. Defaults to MEAN.
173+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
174+
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
175+
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
176176
include_background : whether to include the background class in loss calculation. Defaults to True.
177177
use_softmax: whether to use softmax to transform the original logits into probabilities.
178178
If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -186,15 +186,15 @@ def __init__(
186186
>>> fl(pred, grnd)
187187
"""
188188
super().__init__(reduction=LossReduction(reduction).value)
189-
if use_sigmoid and use_softmax:
190-
raise ValueError("use_sigmoid and use_softmax are mutually exclusive; only one can be True.")
191189
self.to_onehot_y = to_onehot_y
192190
self.num_classes = num_classes
193191
self.gamma = gamma
194192
self.delta = delta
195193
self.weight: float = weight
196-
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
197-
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
194+
self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
195+
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
196+
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
197+
)
198198
self.include_background = include_background
199199
self.use_softmax = use_softmax
200200

@@ -205,8 +205,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
205205
y_pred : the shape should be BNH[WD], where N is the number of classes.
206206
It only supports binary segmentation.
207207
The input should be the original logits since it will be transformed by
208-
a sigmoid in the forward function.
209-
y_true : the shape should be BNH[WD], where N is the number of classes.
208+
a sigmoid or softmax in the forward function.
209+
y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
210210
It only supports binary segmentation.
211211
212212
Raises:
@@ -235,6 +235,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
235235
else:
236236
y_true = one_hot(y_true, num_classes=n_pred_ch)
237237

238+
if not self.include_background:
239+
if n_pred_ch == 1:
240+
warnings.warn("single channel prediction, `include_background=False` ignored.")
241+
else:
242+
# if skipping background, removing first channel
243+
y_pred = y_pred[:, 1:]
244+
y_true = y_true[:, 1:]
245+
246+
if self.use_softmax:
247+
y_pred = torch.softmax(y_pred.float(), dim=1)
248+
else:
249+
y_pred = torch.sigmoid(y_pred.float())
250+
238251
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
239252
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
240253

0 commit comments

Comments
 (0)