Skip to content

Commit 4d247b0

Browse files
committed
adding marat's bench script
1 parent 9bbfa1c commit 4d247b0

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/preprocess-bench.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import argparse
2+
import os
3+
from timeit import default_timer as timer
4+
from tqdm import tqdm
5+
import torch
6+
import torch.utils.data
7+
import torchvision.transforms as transforms
8+
import torchvision.datasets as datasets
9+
10+
11+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
12+
parser.add_argument('--data', metavar='PATH', required=True,
13+
help='path to dataset')
14+
parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N',
15+
help='number of data loading threads (default: 2)')
16+
parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N',
17+
help='mini-batch size (1 = pure stochastic) Default: 256')
18+
19+
20+
if __name__ == "__main__":
21+
args = parser.parse_args()
22+
23+
24+
# Data loading code
25+
transform = transforms.Compose([
26+
transforms.RandomSizedCrop(224),
27+
transforms.RandomHorizontalFlip(),
28+
transforms.ToTensor(),
29+
transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ],
30+
std = [ 0.229, 0.224, 0.225 ]),
31+
])
32+
33+
traindir = os.path.join(args.data, 'train')
34+
valdir = os.path.join(args.data, 'val')
35+
train = datasets.ImageFolder(traindir, transform)
36+
val = datasets.ImageFolder(valdir, transform)
37+
train_loader = torch.utils.data.DataLoader(
38+
train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
39+
train_iter = iter(train_loader)
40+
41+
start_time = timer()
42+
batch_count = 100 * args.nThreads
43+
for i in tqdm(xrange(batch_count)):
44+
batch = next(train_iter)
45+
end_time = timer()
46+
print("Performance: {dataset:.0f} minutes/dataset, {batch:.2f} secs/batch, {image:.2f} ms/image".format(
47+
dataset=(end_time - start_time) * len(train_loader) / (batch_count * args.batchSize) / 60.0,
48+
batch=(end_time - start_time) / float(batch_count),
49+
image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3))
50+

0 commit comments

Comments
 (0)