2222
2323class AsymmetricFocalTverskyLoss (_Loss ):
2424 """
25- AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class .
25+ AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss that focuses on foreground classes .
2626
27- Actually, it's only supported for binary image segmentation now .
27+ Supports multi-class segmentation with optional background inclusion .
2828
2929 Reimplementation of the Asymmetric Focal Tversky Loss described in:
3030
@@ -61,7 +61,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6161
6262 if self .to_onehot_y :
6363 if n_pred_ch == 1 :
64- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
64+ warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." , stacklevel = 2 )
6565 else :
6666 y_true = one_hot (y_true , num_classes = n_pred_ch )
6767
@@ -95,9 +95,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
9595
9696class AsymmetricFocalLoss (_Loss ):
9797 """
98- AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class .
98+ AsymmetricFocalLoss is a variant of Focal Loss that focuses on foreground classes .
9999
100- Actually, it's only supported for binary image segmentation now .
100+ Supports multi-class segmentation with optional background inclusion .
101101
102102 Reimplementation of the Asymmetric Focal Loss described in:
103103
@@ -134,7 +134,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
134134
135135 if self .to_onehot_y :
136136 if n_pred_ch == 1 :
137- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
137+ warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." , stacklevel = 2 )
138138 else :
139139 y_true = one_hot (y_true , num_classes = n_pred_ch )
140140
@@ -161,9 +161,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
161161
162162class AsymmetricUnifiedFocalLoss (_Loss ):
163163 """
164- AsymmetricUnifiedFocalLoss is a variant of Focal Loss.
164+ AsymmetricUnifiedFocalLoss combines Asymmetric Focal Loss and Asymmetric Focal Tversky Loss.
165165
166- Actually, it's only supported for binary image segmentation now
166+ Supports multi-class segmentation with configurable activation (sigmoid/softmax) and optional background inclusion.
167167
168168 Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:
169169
@@ -201,6 +201,11 @@ def __init__(
201201 >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
202202 >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
203203 >>> fl(pred, grnd)
204+ >>> # Multiclass example with 3 classes
205+ >>> pred_mc = torch.randn((1,3,32,32), dtype=torch.float32)
206+ >>> grnd_mc = torch.randint(0, 3, (1,1,32,32), dtype=torch.int64)
207+ >>> fl_mc = AsymmetricUnifiedFocalLoss(to_onehot_y=True, num_classes=3, use_softmax=True)
208+ >>> fl_mc(pred_mc, grnd_mc)
204209 """
205210 super ().__init__ (reduction = LossReduction (reduction ).value )
206211 self .to_onehot_y = to_onehot_y
@@ -225,16 +230,13 @@ def __init__(
225230 reduction = LossReduction .NONE ,
226231 )
227232
228- # TODO: Implement this function to support multiple classes segmentation
229233 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
230234 """
231235 Args:
232236 y_pred : the shape should be BNH[WD], where N is the number of classes.
233- It only supports binary segmentation.
234237 The input should be the original logits since it will be transformed by
235238 a sigmoid or softmax in the forward function.
236239 y_true : the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
237- It only supports binary segmentation.
238240
239241 Raises:
240242 ValueError: When input and target are different shape
@@ -264,11 +266,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
264266 n_pred_ch = y_pred .shape [1 ]
265267 if self .to_onehot_y :
266268 if n_pred_ch == 1 :
267- warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." )
269+ warnings .warn ("single channel prediction, `to_onehot_y=True` ignored." , stacklevel = 2 )
268270 else :
269271 y_true = one_hot (y_true , num_classes = n_pred_ch )
270272
271273 if y_pred .shape [1 ] == 1 :
274+ warnings .warn ("single channel prediction, augmenting with background channel." , stacklevel = 2 )
272275 y_pred_sigmoid = torch .sigmoid (y_pred .float ())
273276 y_pred = torch .cat ([1 - y_pred_sigmoid , y_pred_sigmoid ], dim = 1 )
274277
@@ -278,7 +281,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
278281 if self .use_softmax :
279282 y_pred = torch .softmax (y_pred .float (), dim = 1 )
280283 else :
281- y_pred = torch . sigmoid ( y_pred .float () )
284+ y_pred = y_pred .float ()
282285
283286 asy_focal_loss = self .asy_focal_loss (y_pred , y_true )
284287 asy_focal_tversky_loss = self .asy_focal_tversky_loss (y_pred , y_true )
0 commit comments