Skip to content

Commit 0d115af

Browse files
committed
Adding mixup/cutmix in references script.
1 parent c1bc525 commit 0d115af

File tree

3 files changed: 20 additions and 15 deletions

references/classification/train.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,16 @@ def main(args):
165165
train_dir = os.path.join(args.data_path, 'train')
166166
val_dir = os.path.join(args.data_path, 'val')
167167
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
168+
169+
collate_fn = None
170+
if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
171+
mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
172+
cutmix_alpha=args.cutmix_alpha)
173+
collate_fn = lambda batch: mixupcutmix(*(torch.utils.data._utils.collate.default_collate(batch))) # noqa: E731
168174
data_loader = torch.utils.data.DataLoader(
169175
dataset, batch_size=args.batch_size,
170-
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
171-
176+
sampler=train_sampler, num_workers=args.workers, pin_memory=True,
177+
collate_fn=collate_fn)
172178
data_loader_test = torch.utils.data.DataLoader(
173179
dataset_test, batch_size=args.batch_size,
174180
sampler=test_sampler, num_workers=args.workers, pin_memory=True)
@@ -254,7 +260,6 @@ def main(args):
254260
def get_args_parser(add_help=True):
255261
import argparse
256262
parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
257-
258263
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
259264
parser.add_argument('--model', default='resnet18', help='model')
260265
parser.add_argument('--device', default='cuda', help='device')
@@ -273,6 +278,8 @@ def get_args_parser(add_help=True):
273278
parser.add_argument('--label-smoothing', default=0.0, type=float,
274279
help='label smoothing (default: 0.0)',
275280
dest='label_smoothing')
281+
parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
282+
parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
276283
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
277284
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
278285
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
@@ -306,7 +313,6 @@ def get_args_parser(add_help=True):
306313
)
307314
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
308315
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
309-
310316
# Mixed precision training parameters
311317
parser.add_argument('--apex', action='store_true',
312318
help='Use apex for mixed precision training')
@@ -315,7 +321,6 @@ def get_args_parser(add_help=True):
315321
'O0 for FP32 training, O1 for mixed precision training.'
316322
'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
317323
)
318-
319324
# distributed training parameters
320325
parser.add_argument('--world-size', default=1, type=int,
321326
help='number of distributed processes')
@@ -326,7 +331,6 @@ def get_args_parser(add_help=True):
326331
parser.add_argument(
327332
'--model-ema-decay', type=float, default=0.99,
328333
help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')
329-
330334
return parser
331335

332336

test/test_transforms_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,5 +788,5 @@ def test_random_mixupcutmix_with_real_data():
788788

789789
torch.testing.assert_close(
790790
torch.stack(stats).mean(dim=0),
791-
torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751])
791+
torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751])
792792
)

torchvision/transforms/transforms.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,8 +1955,9 @@ def __repr__(self):
19551955
return self.__class__.__name__ + '(p={})'.format(self.p)
19561956

19571957

1958+
# TODO: move this to references before merging and delete the tests
19581959
class RandomMixupCutmix(torch.nn.Module):
1959-
"""Randomly apply Mixum or Cutmix to the provided batch and targets.
1960+
"""Randomly apply Mixup or Cutmix to the provided batch and targets.
19601961
The class implements the data augmentations as described in the papers
19611962
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
19621963
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20142015
return batch, target
20152016

20162017
# It's faster to roll the batch by one instead of shuffling it to create image pairs
2017-
batch_flipped = batch.roll(1)
2018-
target_flipped = target.roll(1)
2018+
batch_rolled = batch.roll(1, 0)
2019+
target_rolled = target.roll(1)
20192020

20202021
if self.mixup_alpha <= 0.0:
20212022
use_mixup = False
@@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20252026
if use_mixup:
20262027
# Implemented as on mixup paper, page 3.
20272028
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
2028-
batch_flipped.mul_(1.0 - lambda_param)
2029-
batch.mul_(lambda_param).add_(batch_flipped)
2029+
batch_rolled.mul_(1.0 - lambda_param)
2030+
batch.mul_(lambda_param).add_(batch_rolled)
20302031
else:
20312032
# Implemented as on cutmix paper, page 12 (with minor corrections on typos).
20322033
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
@@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
20442045
x2 = int(torch.clamp(r_x + r_w_half, max=W))
20452046
y2 = int(torch.clamp(r_y + r_h_half, max=H))
20462047

2047-
batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
2048+
batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
20482049
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
20492050

2050-
target_flipped.mul_(1.0 - lambda_param)
2051-
target.mul_(lambda_param).add_(target_flipped)
2051+
target_rolled.mul_(1.0 - lambda_param)
2052+
target.mul_(lambda_param).add_(target_rolled)
20522053

20532054
return batch, target
20542055

0 commit comments

Comments (0)