Skip to content

Commit b46f089

Browse files
committed
fix using function twice.
Signed-off-by: ytl0623 <[email protected]>
1 parent b7a5013 commit b46f089

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222

2323
class AsymmetricFocalTverskyLoss(_Loss):
2424
"""
25-
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
25+
AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes.
2626
27-
Actually, it's only supported for binary image segmentation now.
27+
Supports multi-class segmentation with optional background inclusion.
2828
2929
Reimplementation of the Asymmetric Focal Tversky Loss described in:
3030
@@ -61,7 +61,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6161

6262
if self.to_onehot_y:
6363
if n_pred_ch == 1:
64-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
64+
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
6565
else:
6666
y_true = one_hot(y_true, num_classes=n_pred_ch)
6767

@@ -95,9 +95,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
9595

9696
class AsymmetricFocalLoss(_Loss):
9797
"""
98-
AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.
98+
AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes.
9999
100-
Actually, it's only supported for binary image segmentation now.
100+
Supports multi-class segmentation with optional background inclusion.
101101
102102
Reimplementation of the Asymmetric Focal Loss described in:
103103
@@ -134,7 +134,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
134134

135135
if self.to_onehot_y:
136136
if n_pred_ch == 1:
137-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
137+
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
138138
else:
139139
y_true = one_hot(y_true, num_classes=n_pred_ch)
140140

@@ -161,9 +161,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
161161

162162
class AsymmetricUnifiedFocalLoss(_Loss):
163163
"""
164-
AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
164+
AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
165165
166-
Actually, it's only supported for binary image segmentation now
166+
Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion.
167167
168168
Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
169169
@@ -201,6 +201,11 @@ def __init__(
201201
>>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
202202
>>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
203203
>>> fl(pred, grnd)
204+
>>> # Multiclass example with 3 classes
205+
>>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32)
206+
>>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64)
207+
>>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
208+
>>> fl_mc(pred_mc, grnd_mc)
204209
"""
205210
super().__init__(reduction=LossReduction(reduction).value)
206211
self.to_onehot_y = to_onehot_y
@@ -225,16 +230,13 @@ def __init__(
225230
reduction=LossReduction.NONE,
226231
)
227232

228-
# TODO: Implement this function to support multiple classes segmentation
229233
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
230234
"""
231235
Args:
232236
y_pred : the shape should be BNH[WD], where N is the number of classes.
233-
It only supports binary segmentation.
234237
The input should be the original logits since it will be transformed by
235238
a sigmoid or softmax in the forward function.
236239
y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
237-
It only supports binary segmentation.
238240
239241
Raises:
240242
ValueError: When input and target are different shape
@@ -264,11 +266,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
264266
n_pred_ch = y_pred.shape[1]
265267
if self.to_onehot_y:
266268
if n_pred_ch == 1:
267-
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
269+
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2)
268270
else:
269271
y_true = one_hot(y_true, num_classes=n_pred_ch)
270272

271273
if y_pred.shape[1] == 1:
274+
warnings.warn("single channel prediction, augmenting with background channel.", stacklevel=2)
272275
y_pred_sigmoid = torch.sigmoid(y_pred.float())
273276
y_pred = torch.cat([1 - y_pred_sigmoid, y_pred_sigmoid], dim=1)
274277

@@ -278,7 +281,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
278281
if self.use_softmax:
279282
y_pred = torch.softmax(y_pred.float(), dim=1)
280283
else:
281-
y_pred = torch.sigmoid(y_pred.float())
284+
y_pred = y_pred.float()
282285

283286
asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
284287
asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)

0 commit comments

Comments
 (0)