@@ -54,7 +54,7 @@ def __init__(
5454 self .delta = delta
5555 self .gamma = gamma
5656 self .epsilon = epsilon
57- self .include_background = include_background
57+ self .include_background : bool = include_background
5858
5959 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
6060 n_pred_ch = y_pred .shape [1 ]
@@ -77,6 +77,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7777 fn = torch .sum (y_true * (1 - y_pred ), dim = axis )
7878 fp = torch .sum ((1 - y_true ) * y_pred , dim = axis )
7979 dice_class = (tp + self .epsilon ) / (tp + self .delta * fn + (1 - self .delta ) * fp + self .epsilon )
80+ dice_class = torch .clamp (dice_class , self .epsilon , 1.0 - self .epsilon )
8081
8182 # Calculate losses separately for each class, enhancing both classes
8283 back_dice = 1 - dice_class [:, 0 :1 ]
@@ -126,7 +127,7 @@ def __init__(
126127 self .delta = delta
127128 self .gamma = gamma
128129 self .epsilon = epsilon
129- self .include_background = include_background
130+ self .include_background : bool = include_background
130131
131132 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
132133 n_pred_ch = y_pred .shape [1 ]
@@ -154,7 +155,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
154155
155156 all_ce = torch .cat ([back_ce , fore_ce ], dim = 1 )
156157
157- loss = torch .mean (torch . sum ( all_ce , dim = 1 ) )
158+ loss = torch .mean (all_ce )
158159 return loss
159160
160161
@@ -184,11 +185,11 @@ def __init__(
184185 """
185186 Args:
186187 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
187- num_classes : number of classes, it only supports 2 now. Defaults to 2.
188+ num_classes : number of classes. Defaults to 2.
189+ weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
190+ gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
188191 delta : weight of the background. Defaults to 0.7.
189- gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
190- epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
191- weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
192+ reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
192193 include_background : whether to include the background class in loss calculation. Defaults to True.
193194 use_softmax: whether to use softmax to transform the original logits into probabilities.
194195 If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -208,12 +209,20 @@ def __init__(
208209 self .delta = delta
209210 self .weight : float = weight
210211 self .asy_focal_loss = AsymmetricFocalLoss (
211- to_onehot_y = self .to_onehot_y , gamma = self .gamma , delta = self .delta , include_background = self .include_background
212+ to_onehot_y = self .to_onehot_y ,
213+ gamma = self .gamma ,
214+ delta = self .delta ,
215+ include_background = self .include_background ,
216+ reduction = LossReduction .NONE ,
212217 )
213218 self .asy_focal_tversky_loss = AsymmetricFocalTverskyLoss (
214- to_onehot_y = self .to_onehot_y , gamma = self .gamma , delta = self .delta , include_background = self .include_background
219+ to_onehot_y = self .to_onehot_y ,
220+ gamma = self .gamma ,
221+ delta = self .delta ,
222+ include_background = self .include_background ,
223+ reduction = LossReduction .NONE ,
215224 )
216- self .include_background = include_background
225+ self .include_background : bool = include_background
217226 self .use_softmax = use_softmax
218227
219228 # TODO: Implement this function to support multiple classes segmentation
@@ -240,10 +249,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
240249 raise ValueError (f"input shape must be 4 or 5, but got { y_pred .shape } " )
241250
242251 if y_pred .shape [1 ] == 1 :
243- y_pred = one_hot (y_pred , num_classes = self .num_classes )
244- y_true = one_hot (y_true , num_classes = self .num_classes )
245-
246- if torch .max (y_true ) != self .num_classes - 1 :
252+ if self .num_classes != 2 :
253+ raise ValueError (
254+ f"Single-channel input only supported for binary (num_classes=2), got { self .num_classes } "
255+ )
256+ y_pred = torch .cat ([torch .zeros_like (y_pred ), y_pred ], dim = 1 )
257+ if y_true .shape [1 ] == 1 :
258+ y_true = one_hot (y_true , num_classes = self .num_classes )
259+
260+ if y_true .shape [1 ] != self .num_classes and torch .max (y_true ) > self .num_classes - 1 :
247261 raise ValueError (f"Please make sure the number of classes is { self .num_classes - 1 } " )
248262
249263 n_pred_ch = y_pred .shape [1 ]
0 commit comments