@@ -39,19 +39,22 @@ def __init__(
3939 gamma : float = 0.75 ,
4040 epsilon : float = 1e-7 ,
4141 reduction : LossReduction | str = LossReduction .MEAN ,
42+ include_background : bool = True ,
4243 ) -> None :
4344 """
4445 Args:
4546 to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
4647 delta : weight of the background. Defaults to 0.7.
4748 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
4849 epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
50+ include_background: whether to include background class in loss calculation. Defaults to True.
4951 """
5052 super ().__init__ (reduction = LossReduction (reduction ).value )
5153 self .to_onehot_y = to_onehot_y
5254 self .delta = delta
5355 self .gamma = gamma
5456 self .epsilon = epsilon
57+ self .include_background = include_background
5558
5659 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
5760 n_pred_ch = y_pred .shape [1 ]
@@ -79,6 +82,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7982 back_dice = 1 - dice_class [:, 0 ]
8083 fore_dice = (1 - dice_class [:, 1 ]) * torch .pow (1 - dice_class [:, 1 ], - self .gamma )
8184
85+ if not self .include_background :
86+ back_dice = back_dice * 0.0
87+
8288 # Average class scores
8389 loss = torch .mean (torch .stack ([back_dice , fore_dice ], dim = - 1 ))
8490 return loss
@@ -103,19 +109,22 @@ def __init__(
103109 gamma : float = 2 ,
104110 epsilon : float = 1e-7 ,
105111 reduction : LossReduction | str = LossReduction .MEAN ,
112+ include_background : bool = True ,
106113 ):
107114 """
108115 Args:
109116 to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110117 delta : weight of the background. Defaults to 0.7.
111118 gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
112119 epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
120+ include_background: whether to include background class in loss calculation. Defaults to True.
113121 """
114122 super ().__init__ (reduction = LossReduction (reduction ).value )
115123 self .to_onehot_y = to_onehot_y
116124 self .delta = delta
117125 self .gamma = gamma
118126 self .epsilon = epsilon
127+ self .include_background = include_background
119128
120129 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
121130 n_pred_ch = y_pred .shape [1 ]
@@ -138,6 +147,9 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
138147 fore_ce = cross_entropy [:, 1 ]
139148 fore_ce = self .delta * fore_ce
140149
150+ if not self .include_background :
151+ back_ce = back_ce * 0.0
152+
141153 loss = torch .mean (torch .sum (torch .stack ([back_ce , fore_ce ], dim = 1 ), dim = 1 ))
142154 return loss
143155
0 commit comments