@@ -1971,15 +1971,13 @@ class RandomMixupCutmix(torch.nn.Module):
             Default value is 0.5.
         cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
             Set to 0.0 to turn off. Default value is 0.0.
-        label_smoothing (float): the amount of smoothing using when one-hot encoding.
-            Set to 0.0 to turn off. Default value is 0.0.
         inplace (bool): boolean to make this transform inplace. Default set to False.
     """

     def __init__(self, num_classes: int,
                  p: float = 1.0, mixup_alpha: float = 1.0,
                  cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
-                 label_smoothing: float = 0.0, inplace: bool = False) -> None:
+                 inplace: bool = False) -> None:
         super().__init__()
         assert num_classes > 0, "Please provide a valid positive value for the num_classes."
         assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
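Note (not part of the diff): with `label_smoothing` removed from the transform, smoothing can instead be handled downstream, for example at the loss. A minimal sketch, assuming PyTorch >= 1.10, where `nn.CrossEntropyLoss` exposes a `label_smoothing` argument and accepts the floating-point probability targets this transform now emits; the tensor shapes below are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch only: label smoothing applied by the loss rather than by the transform.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 10)                                    # fake model output
target = F.one_hot(torch.randint(0, 10, (8,)), num_classes=10).to(dtype=torch.float32)
loss = criterion(logits, target)                               # soft / one-hot targets are accepted
```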
@@ -1989,16 +1987,8 @@ def __init__(self, num_classes: int,
         self.mixup_alpha = mixup_alpha
         self.cutmix_p = cutmix_p
         self.cutmix_alpha = cutmix_alpha
-        self.label_smoothing = label_smoothing
         self.inplace = inplace

-    def _smooth_one_hot(self, target: Tensor) -> Tensor:
-        N = target.shape[0]
-        device = target.device
-        v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device)
-        return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes,
-                          device=device).scatter_add_(1, target.unsqueeze(1), v)
-
     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Args:
@@ -2014,21 +2004,18 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
         elif target.dtype != torch.int64:
             raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
-        elif batch.size(0) % 2 != 0:
-            # speed optimization, see below
-            raise ValueError("The batch size should be even.")

         if not self.inplace:
             batch = batch.clone()
             # target = target.clone()

-        target = self._smooth_one_hot(target)
+        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
         if torch.rand(1).item() >= self.p:
             return batch, target

-        # It's faster to flip the batch instead of shuffling it to create image pairs
-        batch_flipped = batch.flip(0)
-        target_flipped = target.flip(0)
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_flipped = batch.roll(1, 0)
+        target_flipped = target.roll(1, 0)

         if self.mixup_alpha <= 0.0:
             use_mixup = False
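Note (not part of the diff): rolling along the batch dimension pairs every sample with its predecessor, so no sample is ever mixed with itself, which is consistent with dropping the even-batch-size check above. A toy sketch of the pairing difference; the variable names here are illustrative, not from the PR.

```python
import torch

batch = torch.arange(5).float().view(5, 1)    # 5 "images" with ids 0..4
paired_by_roll = batch.roll(1, 0)             # ids 4, 0, 1, 2, 3 -> every sample gets a distinct partner
paired_by_flip = batch.flip(0)                # ids 4, 3, 2, 1, 0 -> with an odd batch, id 2 pairs with itself
```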
@@ -2072,7 +2059,6 @@ def __repr__(self) -> str:
         s += ', mixup_alpha={mixup_alpha}'
         s += ', cutmix_p={cutmix_p}'
         s += ', cutmix_alpha={cutmix_alpha}'
-        s += ', label_smoothing={label_smoothing}'
         s += ', inplace={inplace}'
         s += ')'
         return s.format(**self.__dict__)
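Note (not part of the diff): a hypothetical usage sketch. The transform operates on whole batches, so one common pattern is to apply it inside a `collate_fn`; `RandomMixupCutmix` is assumed to be in scope (defined in the file this diff patches), and the dataset/loader wiring below is an assumption, not something shown in the PR.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Assumes RandomMixupCutmix is importable from the file this diff patches.
mixup_cutmix = RandomMixupCutmix(num_classes=1000, mixup_alpha=0.2, cutmix_alpha=1.0)

def collate_fn(samples):
    # Collate individual (image, label) samples first, then mix the whole batch.
    batch, target = default_collate(samples)
    return mixup_cutmix(batch, target)

# loader = torch.utils.data.DataLoader(dataset, batch_size=64, collate_fn=collate_fn)
```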