Commit e4d130f

Adding auto-augment and random-erase in the training scripts.
1 parent 5198385 commit e4d130f

1 file changed: 23 additions, 14 deletions

references/classification/train.py

Lines changed: 23 additions & 14 deletions
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path


-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        trans = [
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+        ]
+        if args.auto_augment is not None:
+            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
+            trans.append(transforms.AutoAugment(policy=aa_policy))
+        trans.extend([
+            transforms.ToTensor(),
+            normalize,
+        ])
+        if args.random_erase > 0:
+            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose([
-                transforms.RandomResizedCrop(224),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                normalize,
-            ]))
-        if cache_dataset:
+            transforms.Compose(trans))
+        if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)

     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
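
Not part of the diff, but for context: a minimal standalone sketch of the training transform pipeline this hunk assembles when an AutoAugment policy and a non-zero erase probability are requested. The "imagenet" policy name and 0.1 probability below are illustrative values, not defaults from the commit.

from torchvision import transforms

# Illustrative values; any AutoAugmentPolicy name and erase probability work.
auto_augment = "imagenet"
random_erase = 0.1

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

trans = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
]
if auto_augment is not None:
    # AutoAugmentPolicy accepts the policy name ("imagenet", "cifar10", "svhn").
    aa_policy = transforms.AutoAugmentPolicy(auto_augment)
    trans.append(transforms.AutoAugment(policy=aa_policy))
trans.extend([
    transforms.ToTensor(),
    normalize,
])
if random_erase > 0:
    # RandomErasing operates on tensors, so it is appended after ToTensor().
    trans.append(transforms.RandomErasing(p=random_erase))

preprocess = transforms.Compose(trans)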
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
                 transforms.ToTensor(),
                 normalize,
             ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)

     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
@@ -155,8 +163,7 @@ def main(args):

     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -283,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
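
Also not from the commit: a quick sketch of how the two new flags surface through argparse and reach load_data(train_dir, val_dir, args). The parsed values here are placeholders.

import argparse

# Only the two options added by this commit; the real parser defines many more.
parser = argparse.ArgumentParser()
parser.add_argument('--auto-augment', default=None,
                    help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float,
                    help='random erasing probability (default: 0.0)')

# argparse maps the dashed flags to args.auto_augment / args.random_erase,
# which is what the updated load_data() reads above.
args = parser.parse_args(['--auto-augment', 'imagenet', '--random-erase', '0.1'])
print(args.auto_augment, args.random_erase)  # -> imagenet 0.1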
