
Commit c9c02cb

monatis and ellonde authored
benchmark: add script to prepare imagenet1k for benchmarking (#72)
* Imagenet1k benchmark
* Add note about benchmarking

Co-authored-by: Mathias <[email protected]>
1 parent 030e4e8 commit c9c02cb

File tree

4 files changed: +197 −1 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@
 num_images_per_dir: maximum number of images to read from each one of subdirectories
 output_file: optional. if specified, dump the output to this file instead of stdout
 ```
 
-TODO: share benchmarking results for a common dataset later on.
+See [tests/README.md](tests/README.md) for more info about benchmarking.
 
 ## Future Work
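The parameters in this hunk come from the `benchmark` utility's usage text in the main README; a hypothetical run over the dataset prepared by this commit might look like the sketch below (the binary path, model file, and argument order are assumptions, not taken from this diff):

```sh
# Hypothetical invocation; check the usage text in the main README for the actual arguments.
./bin/benchmark path/to/model.bin path/to/imagenet/dataset 100 results.txt
```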

tests/README.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
## Tests

You can use `prepare_imagenet1k.py` to download and prepare the imagenet1k dataset
in a format expected by the `benchmark` utility.
If you haven't already, you need to install torch and torchvision to
use this Python script:

```sh
pip install -r requirements.txt
```
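For reference, a full preparation run based on the script's argument parser might look like this (run from the `tests` directory; the save path is arbitrary):

```sh
pip install -r requirements.txt
python prepare_imagenet1k.py --save_path ./imagenet --verbose
```

The processed class folders end up under `<save_path>/dataset`, which is the directory to point the `benchmark` utility at.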
## Note about benchmark results

Please note that the results in this benchmark do not match those reported in the open-clip repository, because:

1. Most importantly, open-clip uses a different test protocol that includes averaging the text embeddings of multiple prompt templates per class.
2. There are still gotchas in the tokenization implementation in this repo.
3. This repo uses linear interpolation instead of bicubic interpolation in image preprocessing.

The 2nd and 3rd items will be fixed soon.
I don't agree with their test protocol, so I am not very motivated to fix the first item.

tests/prepare_imagenet1k.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
"""
Small script to download and parse the imagenet1k dataset into the benchmark format.

Dataset notes:
- Some class names contain '/'; for compatibility with folder-based benchmarks,
  we replace it with 'or'.
- Classes 744 (missiles) and 837 (sunglasses) are skipped as they are duplicates.
"""

import argparse
import json
import os
import shutil
from pathlib import Path
from subprocess import call

from torchvision.datasets import ImageNet

# Files
_CLASSNAMES_FILENAME = "classnames.json"
_CLASSTEMPLATES_FILENAME = "class_templates.json"
_DEVKIT_FILENAME = "ILSVRC2012_devkit_t12.tar.gz"
_IMG_VAL_FILENAME = "ILSVRC2012_img_val.tar"

# Name for the folder with the final dataset
_PROCESSED_DIR_NAME = "dataset"


def download_dataset(path: Path, verbose: bool = False):
    if verbose:
        print("Downloading dataset")
    path.mkdir(exist_ok=True, parents=True)

    dk_output_path = path / _DEVKIT_FILENAME
    iv_output_path = path / _IMG_VAL_FILENAME

    template_path = path / _CLASSTEMPLATES_FILENAME
    classnames_path = path / _CLASSNAMES_FILENAME

    if not dk_output_path.exists():
        if verbose:
            print("\tDidn't find devkit file, downloading..")
        call(
            (
                f"wget https://image-net.org/data/ILSVRC/2012/{_DEVKIT_FILENAME} "
                + f"--output-document={dk_output_path}"
            ),
            shell=True,
        )
    else:
        if verbose:
            print("\tFound devkit file, skipping download..")

    if not iv_output_path.exists():
        if verbose:
            print("\tDidn't find image validation file, downloading..")
        call(
            (
                f"wget https://image-net.org/data/ILSVRC/2012/{_IMG_VAL_FILENAME} "
                + f"--output-document={iv_output_path}"
            ),
            shell=True,
        )
    else:
        if verbose:
            print("\tFound image validation file, skipping download..")

    if not template_path.exists():
        if verbose:
            print("\tDidn't find class templates file, downloading..")
        call(
            (
                "wget "
                + "https://raw.githubusercontent.com/LAION-AI/CLIP_benchmark/main/clip_benchmark/datasets/en_zeroshot_classification_templates.json "
                + f"--output-document={template_path}"
            ),
            shell=True,
        )

        # Keep only the imagenet1k templates from the downloaded file
        class_templates = json.load(template_path.open("r"))
        class_templates = class_templates["imagenet1k"]
        json.dump(class_templates, template_path.open("w"), indent=2)
    else:
        if verbose:
            print("\tFound class templates file, skipping download..")

    if not classnames_path.exists():
        if verbose:
            print("\tDidn't find class names file, downloading..")
        call(
            (
                "wget "
                + "https://raw.githubusercontent.com/LAION-AI/CLIP_benchmark/main/clip_benchmark/datasets/en_classnames.json "
                + f"--output-document={classnames_path}"
            ),
            shell=True,
        )
        classnames = json.load(classnames_path.open("r"))
        classnames = classnames["imagenet1k"]

        if verbose:
            print(
                "\tFixing classnames, replacing '/' with 'or' and removing duplicates.."
            )
        # Described in the top docstring
        classnames = [
            c.replace("/", "or")
            for i, c in enumerate(classnames)
            if i not in [744, 837]
        ]

        json.dump(classnames, classnames_path.open("w"), indent=2)


def parse_dataset(path: Path, verbose=False):
    if verbose:
        print("Parsing dataset")
    # Load classes
    classes_path = path.joinpath(_CLASSNAMES_FILENAME)
    classes = json.load(classes_path.open("r"))

    # Check if the dataset has already been processed
    processed_dataset_path = path / _PROCESSED_DIR_NAME
    dataset_exists = all(processed_dataset_path.joinpath(c).exists() for c in classes)

    if dataset_exists:
        return processed_dataset_path

    processed_dataset_path.mkdir(exist_ok=True)

    # The torchvision ImageNet dataset handles the unpacking and parsing
    if verbose:
        print("\tUnpacking dataset, this can take a bit..")
    ds = ImageNet(root=path, split="val")

    # Track with a separate counter, since two classes are skipped
    cls_index = 0
    for i, dir_name in enumerate(ds.wnids):
        if dir_name in ["n04356056", "n04008634"]:
            if verbose:
                print("\tSkipped class", ds.classes[i])
            continue

        class_name = classes[cls_index]
        src_dir = Path(ds.split_folder).joinpath(dir_name)
        dst_dir = processed_dataset_path.joinpath(class_name)

        os.rename(src=src_dir, dst=dst_dir)
        if verbose:
            print(f"\tMoved class: {ds.classes[i]} to {class_name}")

        cls_index += 1

    # Remove other files
    shutil.rmtree(ds.split_folder)
    if verbose:
        print("\tCleaned up unpacked dataset folder")

    return processed_dataset_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_path", type=str, required=True)
    parser.add_argument("--verbose", action=argparse.BooleanOptionalAction)
    args = parser.parse_args()

    path = Path(args.save_path).absolute()

    download_dataset(path=path, verbose=args.verbose)
    dataset_path = parse_dataset(path=path, verbose=args.verbose)
    print(f"Dataset is ready at {dataset_path}")

tests/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
torch
torchvision
