|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +from timeit import default_timer as timer |
| 4 | +from tqdm import tqdm |
| 5 | +import torch |
| 6 | +import torch.utils.data |
| 7 | +import torchvision.transforms as transforms |
| 8 | +import torchvision.datasets as datasets |
| 9 | + |
| 10 | + |
| 11 | +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') |
| 12 | +parser.add_argument('--data', metavar='PATH', required=True, |
| 13 | + help='path to dataset') |
| 14 | +parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N', |
| 15 | + help='number of data loading threads (default: 2)') |
| 16 | +parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N', |
| 17 | + help='mini-batch size (1 = pure stochastic) Default: 256') |
| 18 | + |
| 19 | + |
| 20 | +if __name__ == "__main__": |
| 21 | + args = parser.parse_args() |
| 22 | + |
| 23 | + |
| 24 | + # Data loading code |
| 25 | + transform = transforms.Compose([ |
| 26 | + transforms.RandomSizedCrop(224), |
| 27 | + transforms.RandomHorizontalFlip(), |
| 28 | + transforms.ToTensor(), |
| 29 | + transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], |
| 30 | + std = [ 0.229, 0.224, 0.225 ]), |
| 31 | + ]) |
| 32 | + |
| 33 | + traindir = os.path.join(args.data, 'train') |
| 34 | + valdir = os.path.join(args.data, 'val') |
| 35 | + train = datasets.ImageFolder(traindir, transform) |
| 36 | + val = datasets.ImageFolder(valdir, transform) |
| 37 | + train_loader = torch.utils.data.DataLoader( |
| 38 | + train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads) |
| 39 | + train_iter = iter(train_loader) |
| 40 | + |
| 41 | + start_time = timer() |
| 42 | + batch_count = 100 * args.nThreads |
| 43 | + for i in tqdm(xrange(batch_count)): |
| 44 | + batch = next(train_iter) |
| 45 | + end_time = timer() |
| 46 | + print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format( |
| 47 | + dataset=(end_time - start_time) * len(train_loader) / (batch_count * args.batchSize) / 60.0, |
| 48 | + batch=(end_time - start_time) / float(batch_count), |
| 49 | + image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3)) |
| 50 | + |
0 commit comments