Skip to content

Commit 08bfb0e

Browse files
committed
minor fixes
Signed-off-by: ytl0623 <[email protected]>
1 parent 81af139 commit 08bfb0e

File tree

1 file changed

+28
-14
lines changed

1 file changed

+28
-14
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
self.delta = delta
5555
self.gamma = gamma
5656
self.epsilon = epsilon
57-
self.include_background = include_background
57+
self.include_background: bool = include_background
5858

5959
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6060
n_pred_ch = y_pred.shape[1]
@@ -77,6 +77,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7777
fn = torch.sum(y_true * (1 - y_pred), dim=axis)
7878
fp = torch.sum((1 - y_true) * y_pred, dim=axis)
7979
dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
80+
dice_class = torch.clamp(dice_class, self.epsilon, 1.0 - self.epsilon)
8081

8182
# Calculate losses separately for each class, enhancing both classes
8283
back_dice = 1 - dice_class[:, 0:1]
@@ -126,7 +127,7 @@ def __init__(
126127
self.delta = delta
127128
self.gamma = gamma
128129
self.epsilon = epsilon
129-
self.include_background = include_background
130+
self.include_background: bool = include_background
130131

131132
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
132133
n_pred_ch = y_pred.shape[1]
@@ -154,7 +155,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
154155

155156
all_ce = torch.cat([back_ce, fore_ce], dim=1)
156157

157-
loss = torch.mean(torch.sum(all_ce, dim=1))
158+
loss = torch.mean(all_ce)
158159
return loss
159160

160161

@@ -184,11 +185,11 @@ def __init__(
184185
"""
185186
Args:
186187
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
187-
num_classes : number of classes, it only supports 2 now. Defaults to 2.
188+
num_classes : number of classes. Defaults to 2.
189+
weight : weight for combining focal loss and focal tversky loss. Defaults to 0.5.
190+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
188191
delta : weight of the background. Defaults to 0.7.
189-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
190-
epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
191-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
192+
reduction : reduction mode for the loss. Defaults to LossReduction.MEAN.
192193
include_background : whether to include the background class in loss calculation. Defaults to True.
193194
use_softmax: whether to use softmax to transform the original logits into probabilities.
194195
If True, softmax is used. If False, sigmoid is used. Defaults to False.
@@ -208,12 +209,20 @@ def __init__(
208209
self.delta = delta
209210
self.weight: float = weight
210211
self.asy_focal_loss = AsymmetricFocalLoss(
211-
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
212+
to_onehot_y=self.to_onehot_y,
213+
gamma=self.gamma,
214+
delta=self.delta,
215+
include_background=self.include_background,
216+
reduction=LossReduction.NONE,
212217
)
213218
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(
214-
to_onehot_y=self.to_onehot_y, gamma=self.gamma, delta=self.delta, include_background=self.include_background
219+
to_onehot_y=self.to_onehot_y,
220+
gamma=self.gamma,
221+
delta=self.delta,
222+
include_background=self.include_background,
223+
reduction=LossReduction.NONE,
215224
)
216-
self.include_background = include_background
225+
self.include_background: bool = include_background
217226
self.use_softmax = use_softmax
218227

219228
# TODO: Implement this function to support multiple classes segmentation
@@ -240,10 +249,15 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
240249
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")
241250

242251
if y_pred.shape[1] == 1:
243-
y_pred = one_hot(y_pred, num_classes=self.num_classes)
244-
y_true = one_hot(y_true, num_classes=self.num_classes)
245-
246-
if torch.max(y_true) != self.num_classes - 1:
252+
if self.num_classes != 2:
253+
raise ValueError(
254+
f"Single-channel input only supported for binary (num_classes=2), got {self.num_classes}"
255+
)
256+
y_pred = torch.cat([torch.zeros_like(y_pred), y_pred], dim=1)
257+
if y_true.shape[1] == 1:
258+
y_true = one_hot(y_true, num_classes=self.num_classes)
259+
260+
if y_true.shape[1] != self.num_classes and torch.max(y_true) > self.num_classes - 1:
247261
raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
248262

249263
n_pred_ch = y_pred.shape[1]

0 commit comments

Comments
 (0)