Skip to content

Commit f3ddbf5

Browse files
committed
Fix bug on reference script.
1 parent a4ec036 commit f3ddbf5

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

references/classification/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,12 @@ def load_data(traindir, valdir, args):
9292
print("Loading dataset_train from {}".format(cache_path))
9393
dataset, _ = torch.load(cache_path)
9494
else:
95+
auto_augment_policy = getattr(args, "auto_augment", None)
96+
random_erase_prob = getattr(args, "random_erase", None)
9597
dataset = torchvision.datasets.ImageFolder(
9698
traindir,
97-
presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
98-
random_erase_prob=args.random_erase))
99+
presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
100+
random_erase_prob=random_erase_prob))
99101
if args.cache_dataset:
100102
print("Saving dataset_train to {}".format(cache_path))
101103
utils.mkdir(os.path.dirname(cache_path))

references/classification/train_quantization.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def main(args):
3737
train_dir = os.path.join(args.data_path, 'train')
3838
val_dir = os.path.join(args.data_path, 'val')
3939

40-
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
41-
args.cache_dataset, args.distributed)
40+
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
4241
data_loader = torch.utils.data.DataLoader(
4342
dataset, batch_size=args.batch_size,
4443
sampler=train_sampler, num_workers=args.workers, pin_memory=True)

0 commit comments

Comments
 (0)