
Commit 6f2ebea

Move out label_smoothing and try roll instead of flip
1 parent: c15dce0

File tree: 2 files changed (+9, -26 lines)


test/test_transforms_tensor.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -728,15 +728,14 @@ def test_gaussian_blur(device, meth_kwargs):
     {"mixup_alpha": 0.0, "cutmix_alpha": 1.0},
     {"mixup_alpha": 1.0, "cutmix_alpha": 0.0},
 ])
-@pytest.mark.parametrize('label_smoothing', [0.0, 0.1])
 @pytest.mark.parametrize('inplace', [True, False])
-def test_random_mixupcutmix(device, alphas, label_smoothing, inplace):
+def test_random_mixupcutmix(device, alphas, inplace):
     batch_size = 32
     num_classes = 10
     batch = torch.rand(batch_size, 3, 44, 56, device=device)
     targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64)
 
-    fn = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas)
+    fn = T.RandomMixupCutmix(num_classes, inplace=inplace, **alphas)
     scripted_fn = torch.jit.script(fn)
 
     seed = torch.seed()
@@ -763,8 +762,6 @@ def test_random_mixupcutmix_with_invalid_data():
         t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1)))
     with pytest.raises(ValueError, match="Target dtype should be torch.int64."):
         t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32))
-    with pytest.raises(ValueError, match="The batch size should be even."):
-        t(torch.rand(31, 3, 60, 60), torch.randint(10, (31, )))
 
 
 def test_random_mixupcutmix_with_real_data():
@@ -779,7 +776,7 @@ def test_random_mixupcutmix_with_real_data():
     dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1]))
 
     # Use mixup in the collate
-    mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1)
+    mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0)
     dataloader = DataLoader(dataset, batch_size=2,
                             collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch))))
 
@@ -791,5 +788,5 @@ def test_random_mixupcutmix_with_real_data():
 
     torch.testing.assert_close(
         torch.stack(stats).mean(dim=0),
-        torch.tensor([46.94434738, 64.79092407, 0.23949696])
+        torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751])
     )
```
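The `label_smoothing` knob disappears from these tests because the transform no longer smooths targets itself; it now emits plain one-hot vectors (see the transforms.py diff below). A minimal sketch, assuming the smoothing is meant to move into the loss function instead: `torch.nn.CrossEntropyLoss` has accepted a `label_smoothing` argument and probabilistic (e.g. one-hot or mixed) targets since PyTorch 1.10.

```python
import torch
import torch.nn.functional as F

num_classes = 10
target = torch.randint(num_classes, (4,))

# What the transform now produces: plain one-hot float targets.
one_hot = F.one_hot(target, num_classes=num_classes).to(dtype=torch.float32)

# Smoothing can then live in the loss rather than in the transform
# (CrossEntropyLoss supports label_smoothing and probabilistic
# targets since PyTorch 1.10).
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
logits = torch.randn(4, num_classes)
loss = criterion(logits, one_hot)
```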

torchvision/transforms/transforms.py

Lines changed: 5 additions & 19 deletions
```diff
@@ -1971,15 +1971,13 @@ class RandomMixupCutmix(torch.nn.Module):
             Default value is 0.5.
         cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
             Set to 0.0 to turn off. Default value is 0.0.
-        label_smoothing (float): the amount of smoothing using when one-hot encoding.
-            Set to 0.0 to turn off. Default value is 0.0.
         inplace (bool): boolean to make this transform inplace. Default set to False.
     """
 
     def __init__(self, num_classes: int,
                  p: float = 1.0, mixup_alpha: float = 1.0,
                  cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
-                 label_smoothing: float = 0.0, inplace: bool = False) -> None:
+                 inplace: bool = False) -> None:
         super().__init__()
         assert num_classes > 0, "Please provide a valid positive value for the num_classes."
         assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
@@ -1989,16 +1987,8 @@ def __init__(self, num_classes: int,
         self.mixup_alpha = mixup_alpha
         self.cutmix_p = cutmix_p
         self.cutmix_alpha = cutmix_alpha
-        self.label_smoothing = label_smoothing
         self.inplace = inplace
 
-    def _smooth_one_hot(self, target: Tensor) -> Tensor:
-        N = target.shape[0]
-        device = target.device
-        v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device)
-        return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes,
-                          device=device).scatter_add_(1, target.unsqueeze(1), v)
-
     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Args:
@@ -2014,21 +2004,18 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
         elif target.dtype != torch.int64:
             raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
-        elif batch.size(0) % 2 != 0:
-            # speed optimization, see below
-            raise ValueError("The batch size should be even.")
 
         if not self.inplace:
             batch = batch.clone()
             # target = target.clone()
 
-        target = self._smooth_one_hot(target)
+        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
         if torch.rand(1).item() >= self.p:
             return batch, target
 
-        # It's faster to flip the batch instead of shuffling it to create image pairs
-        batch_flipped = batch.flip(0)
-        target_flipped = target.flip(0)
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_flipped = batch.roll(1)
+        target_flipped = target.roll(1)
 
         if self.mixup_alpha <= 0.0:
             use_mixup = False
@@ -2072,7 +2059,6 @@ def __repr__(self) -> str:
         s += ', mixup_alpha={mixup_alpha}'
         s += ', cutmix_p={cutmix_p}'
         s += ', cutmix_alpha={cutmix_alpha}'
-        s += ', label_smoothing={label_smoothing}'
         s += ', inplace={inplace}'
         s += ')'
         return s.format(**self.__dict__)
```
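Why roll rather than flip, and why the even-batch-size check could be dropped: `flip(0)` reverses the batch, so the middle sample of an odd-sized batch would be mixed with itself, whereas rolling by one position gives every sample a distinct partner at any batch size. A small illustration, not part of the commit, using a 1-D tensor where `roll(1)` coincides with `roll(1, 0)`; for an N-D batch, pairing samples would need an explicit dimension (`roll(1, 0)`), since `torch.roll` without one operates on the flattened tensor.

```python
import torch

batch_ids = torch.tensor([0, 1, 2, 3, 4])  # an odd-sized batch

# flip(0) reverses the batch: sample 2 is paired with itself,
# which is why the old implementation rejected odd batch sizes.
print(batch_ids.flip(0))   # tensor([4, 3, 2, 1, 0])

# Rolling by one pairs each sample with its neighbour instead,
# so no sample meets itself and any batch size works.
print(batch_ids.roll(1))   # tensor([4, 0, 1, 2, 3])
```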

0 commit comments
