@@ -163,16 +163,16 @@ def __init__(
         delta: float = 0.7,
         reduction: LossReduction | str = LossReduction.MEAN,
         include_background: bool = True,
-        use_softmax: bool = False
+        use_softmax: bool = False,
     ):
         """
         Args:
             to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
             num_classes: number of classes, it only supports 2 now. Defaults to 2.
-            weight: weight for combining focal loss and focal tversky loss. Defaults to 0.5.
-            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
             delta: weight of the background. Defaults to 0.7.
-            reduction: reduction mode for the loss. Defaults to MEAN.
+            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
+            epsilon: a small constant added to avoid division by zero, similar to a smoothing term. Defaults to 1e-7.
+            weight: weight for combining focal loss and focal tversky loss; if None, 0.5 is used. Defaults to None.
             include_background: whether to include the background class in loss calculation. Defaults to True.
             use_softmax: whether to use softmax to transform the original logits into probabilities.
                 If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -186,15 +186,15 @@ def __init__(
         >>> fl(pred, grnd)
         """
         super().__init__(reduction=LossReduction(reduction).value)
-        if use_sigmoid and use_softmax:
-            raise ValueError("use_sigmoid and use_softmax are mutually exclusive; only one can be True.")
         self.to_onehot_y = to_onehot_y
         self.num_classes = num_classes
         self.gamma = gamma
         self.delta = delta
         self.weight: float = weight
-        self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
-        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)
+        self.asy_focal_loss = AsymmetricFocalLoss(to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta)
+        self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
+            to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta
+        )
         self.include_background = include_background
         self.use_softmax = use_softmax
 
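
For context, a minimal usage sketch of the updated constructor, assuming this class is MONAI's AsymmetricUnifiedFocalLoss (the import path below is an assumption; this diff does not show it):

import torch
from monai.losses import AsymmetricUnifiedFocalLoss  # assumed import path

# Binary segmentation: 2-channel logits, single-channel integer labels.
loss_fn = AsymmetricUnifiedFocalLoss(
    to_onehot_y=True,   # labels arrive as B1H[WD] class indices
    delta=0.7,
    use_softmax=True,   # softmax over channels instead of per-channel sigmoid
)
pred = torch.randn(4, 2, 64, 64)             # BNHW raw logits, N == 2
grnd = torch.randint(0, 2, (4, 1, 64, 64))   # B1HW class indices
loss = loss_fn(pred, grnd)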
@@ -205,8 +205,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
             y_pred: the shape should be BNH[WD], where N is the number of classes.
                 It only supports binary segmentation.
                 The input should be the original logits since it will be transformed by
-                a sigmoid in the forward function.
-            y_true: the shape should be BNH[WD], where N is the number of classes.
+                a sigmoid or softmax in the forward function.
+            y_true: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
                 It only supports binary segmentation.
 
         Raises:
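
To make the two label layouts the docstring now accepts concrete, a small sketch using MONAI's one_hot helper (the same helper the forward pass calls below; the import path is an assumption):

import torch
from monai.networks.utils import one_hot  # assumed import path

idx = torch.randint(0, 2, (2, 1, 32, 32))  # B1HW class indices; needs to_onehot_y=True
oh = one_hot(idx, num_classes=2)           # BNHW one-hot; passed through unchanged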
@@ -235,6 +235,19 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
         else:
             y_true = one_hot(y_true, num_classes=n_pred_ch)
 
+        if not self.include_background:
+            if n_pred_ch == 1:
+                warnings.warn("single channel prediction, `include_background=False` ignored.")
+            else:
+                # if skipping background, remove the first channel
+                y_pred = y_pred[:, 1:]
+                y_true = y_true[:, 1:]
+
+        if self.use_softmax:
+            y_pred = torch.softmax(y_pred.float(), dim=1)
+        else:
+            y_pred = torch.sigmoid(y_pred.float())
+
         asy_focal_loss = self.asy_focal_loss(y_pred, y_true)
         asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true)
 
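
A standalone sketch of what the two inserted preprocessing steps do to the tensors, here with include_background=False and the sigmoid path (all names and shapes are illustrative stand-ins, not the class's attributes):

import torch

y_pred = torch.randn(4, 2, 64, 64)                    # raw logits, BNHW
y_true = torch.randint(0, 2, (4, 2, 64, 64)).float()  # dummy one-hot-style labels

# include_background=False: drop the background (first) channel from both
y_pred, y_true = y_pred[:, 1:], y_true[:, 1:]         # now B1HW

# use_softmax=False: per-channel sigmoid (softmax over dim=1 when True)
y_pred = torch.sigmoid(y_pred.float())

The two component losses computed above are then presumably blended with weight (0.5 when None), following the Unified Focal loss formulation; that combination step falls outside this hunk.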