@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path
 
 
-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-        if cache_dataset:
+            transforms.Compose(trans))
+        if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
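Note the ordering of the appended transforms: AutoAugment is inserted before ToTensor so it runs on PIL images, while RandomErasing is appended last because torchvision's RandomErasing only accepts tensor images. As a minimal sketch, with both new flags enabled the composed training pipeline is equivalent to the following ("imagenet" is one of the policy strings accepted by AutoAugmentPolicy, alongside "cifar10" and "svhn"):

import torchvision.transforms as transforms

# Pipeline load_data builds when --auto-augment imagenet --random-erase 0.1
# are set. RandomErasing must follow ToTensor since it expects a torch.Tensor.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy("imagenet")),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1),
])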
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
                 transforms.ToTensor(),
                 normalize,
             ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)
 
     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
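A caveat this diff leaves unchanged: DistributedSampler seeds its shuffle from the epoch number, so the training loop must call set_epoch every epoch or each rank replays the same sample ordering. A sketch of the usual pattern (train_one_epoch stands in for the script's per-epoch training function, with its arguments abbreviated):

for epoch in range(args.start_epoch, args.epochs):
    if args.distributed:
        # Re-seed the sampler so every epoch sees a different shuffle.
        train_sampler.set_epoch(epoch)
    train_one_epoch(model, criterion, optimizer, data_loader, device, epoch)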
@@ -155,8 +163,7 @@ def main(args):
 
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
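Passing the whole args namespace through load_data keeps its signature stable as new data options are added. For a quick standalone sanity check, something like this sketch should work (argparse.Namespace stands in for the parsed CLI arguments; the dataset paths are placeholders):

import argparse

args = argparse.Namespace(cache_dataset=False, distributed=False,
                          auto_augment="imagenet", random_erase=0.1)
dataset, dataset_test, train_sampler, test_sampler = load_data(
    "/data/imagenet/train", "/data/imagenet/val", args)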
@@ -283,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
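Both new options default to off, so existing invocations are unaffected. Enabling them should look something like the following, where the data path is a placeholder:

python train.py --data-path /path/to/imagenet --auto-augment imagenet --random-erase 0.1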