@@ -165,10 +165,16 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
+        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
+                                                               cutmix_alpha=args.cutmix_alpha)
+        collate_fn = lambda batch: mixupcutmix(*(torch.utils.data._utils.collate.default_collate(batch)))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
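
Note: the new collate_fn composes torchvision's default batch collation with the batch-level transform. default_collate stacks the sampled (image, target) pairs into batch tensors, and RandomMixupCutmix then mixes that stacked batch, since mixup and cutmix operate across samples rather than per sample. A minimal standalone sketch of the same pattern, where batch_transform stands in for the RandomMixupCutmix instance and make_mixing_collate_fn is an illustrative helper name, not part of this commit:

    import torch
    from torch.utils.data._utils.collate import default_collate

    def make_mixing_collate_fn(batch_transform):
        # Stack individual (image, target) samples into (images, targets)
        # tensors, then hand the whole stacked batch to the mixing transform.
        def collate_fn(batch):
            return batch_transform(*default_collate(batch))
        return collate_fn

The PR expresses the same composition as a lambda, hence the noqa: E731 comment suppressing flake8's "do not assign a lambda expression" warning.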
@@ -254,7 +260,6 @@ def main(args):
 def get_args_parser(add_help=True):
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
-
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
     parser.add_argument('--model', default='resnet18', help='model')
     parser.add_argument('--device', default='cuda', help='device')
@@ -273,6 +278,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
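
Note: with the two new flags wired into the parser, both augmentations can be enabled from the command line. A hypothetical invocation (the alpha values are illustrative, echoing common defaults from the mixup and CutMix papers, not defaults set by this commit):

    python train.py --data-path /path/to/imagenet --model resnet50 --mixup-alpha 0.2 --cutmix-alpha 1.0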
@@ -306,7 +313,6 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
     parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
-
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
                         help='Use apex for mixed precision training')
@@ -315,7 +321,6 @@ def get_args_parser(add_help=True):
         'O0 for FP32 training, O1 for mixed precision training.'
         'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
     )
-
     # distributed training parameters
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
@@ -326,7 +331,6 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         '--model-ema-decay', type=float, default=0.99,
         help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')
-
     return parser
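
Note: for context on what RandomMixupCutmix computes, here is a hedged sketch of the mixup half of the transform; it is illustrative and not the code added by this commit. A weight lam is drawn from Beta(alpha, alpha), each image is blended with a shuffled partner, and the one-hot targets are blended with the same weight, so the loss sees soft labels:

    import torch

    def mixup_batch(images, targets, num_classes, alpha=0.2):
        # Illustrative mixup only; cutmix instead pastes a random crop of the
        # partner image and sets lam to the ratio of the pasted area.
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
        perm = torch.randperm(images.size(0))
        mixed_images = lam * images + (1.0 - lam) * images[perm]
        onehot = torch.nn.functional.one_hot(targets, num_classes).float()
        mixed_targets = lam * onehot + (1.0 - lam) * onehot[perm]
        return mixed_images, mixed_targets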