Skip to content

Commit 1a28917

Browse files
committed
Enhanced AsymmetricUnifiedFocalLoss with Sigmoid/Softmax
Signed-off-by: ytl0623 <[email protected]>
1 parent 83d2318 commit 1a28917

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,22 @@ def __init__(
3939
gamma: float = 0.75,
4040
epsilon: float = 1e-7,
4141
reduction: LossReduction | str = LossReduction.MEAN,
42+
include_background: bool = True,
4243
) -> None:
4344
"""
4445
Args:
4546
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
4647
delta : weight of the background. Defaults to 0.7.
4748
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
4849
epsilon : a very small constant used as a smoothing value to avoid division by zero. Defaults to 1e-7.
50+
include_background: whether to include background class in loss calculation. Defaults to True.
4951
"""
5052
super().__init__(reduction=LossReduction(reduction).value)
5153
self.to_onehot_y = to_onehot_y
5254
self.delta = delta
5355
self.gamma = gamma
5456
self.epsilon = epsilon
57+
self.include_background = include_background
5558

5659
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
5760
n_pred_ch = y_pred.shape[1]
@@ -79,6 +82,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7982
back_dice = 1 - dice_class[:, 0]
8083
fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma)
8184

85+
if not self.include_background:
86+
back_dice = back_dice * 0.0
87+
8288
# Average class scores
8389
loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1))
8490
return loss
@@ -103,19 +109,22 @@ def __init__(
103109
gamma: float = 2,
104110
epsilon: float = 1e-7,
105111
reduction: LossReduction | str = LossReduction.MEAN,
112+
include_background: bool = True,
106113
):
107114
"""
108115
Args:
109116
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110117
delta : weight of the background. Defaults to 0.7.
111118
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
112119
epsilon : a very small constant used as a smoothing value to avoid division by zero. Defaults to 1e-7.
120+
include_background: whether to include background class in loss calculation. Defaults to True.
113121
"""
114122
super().__init__(reduction=LossReduction(reduction).value)
115123
self.to_onehot_y = to_onehot_y
116124
self.delta = delta
117125
self.gamma = gamma
118126
self.epsilon = epsilon
127+
self.include_background = include_background
119128

120129
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
121130
n_pred_ch = y_pred.shape[1]
@@ -138,6 +147,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
138147
fore_ce = cross_entropy[:, 1]
139148
fore_ce = self.delta * fore_ce
140149

150+
if not self.include_background:
151+
back_ce = back_ce * 0.0
152+
141153
loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1))
142154
return loss
143155

0 commit comments

Comments
 (0)