Commit d0063f3
Authored by jgbradley1 (Joshua Bradley), pmeier, and vfdev-5
Add widerface dataset (#2883)
* initial commit of widerface dataset
* comment out old code
* improve parsing of annotation files
* code cleanup and fix docstring comments
* speed up check for quota exceeded
* cleanup print statements
* reformat code and remove print statements
* minor code cleanup and reformatting
* add more comments
* reuse variable
* reverse formatting changes
* fix flake8 errors
* add type annotations
* fix mypy errors
* add a base_folder to root directory
* some formatting fixes
* GDrive threshold does not throw 403 error
* testing new download logic
* cleanup logic for download and integrity check
* use a better variable name
* format fix
* reorder list in docstring
* initial widerface unit test - fails on MD5 check
* use list of dictionaries to store dataset
* fix docstring formatting
* remove unnecessary error checking
* fix type checker error
* revert typo fix
* rename var constants, use file context manager, verify str args
* fix flake8 error
* fix checking target_type argument values
* create uncompressed dataset folders
* cleanup unit tests for widerface
* use correct os function
* add more info to docstring
* disable unittests for windows
* fix _check_integrity logic
* update docstring
* remove citation
* remove target_type option
* fix formatting issue (Co-authored-by: Philip Meier <[email protected]>)
* remove comment and add more info to docstring
* update type annotations
* restart CI jobs

Co-authored-by: Joshua Bradley <[email protected]>
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: vfdev <[email protected]>
1 parent 3ee34eb commit d0063f3

File tree

4 files changed: +275 -3 lines

test/fakedata_generation.py

Lines changed: 67 additions & 0 deletions
@@ -171,6 +171,73 @@ def _make_devkit_archive(root):
         yield root


+@contextlib.contextmanager
+def widerface_root():
+    """
+    Generates a dataset with the following folder structure and returns the path root:
+    <root>
+    └── widerface
+        ├── wider_face_split
+        ├── WIDER_train
+        ├── WIDER_val
+        └── WIDER_test
+
+    The dataset consists of 1 image for each dataset split (train, val, test)
+    and annotation files for each split.
+    """
+
+    def _make_image(file):
+        PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)
+
+    def _make_train_archive(root):
+        extracted_dir = os.path.join(root, 'WIDER_train', 'images', '0--Parade')
+        os.makedirs(extracted_dir)
+        _make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_1.jpg'))
+
+    def _make_val_archive(root):
+        extracted_dir = os.path.join(root, 'WIDER_val', 'images', '0--Parade')
+        os.makedirs(extracted_dir)
+        _make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_2.jpg'))
+
+    def _make_test_archive(root):
+        extracted_dir = os.path.join(root, 'WIDER_test', 'images', '0--Parade')
+        os.makedirs(extracted_dir)
+        _make_image(os.path.join(extracted_dir, '0_Parade_marchingband_1_3.jpg'))
+
+    def _make_annotations_archive(root):
+        train_bbox_contents = '0--Parade/0_Parade_marchingband_1_1.jpg\n1\n449 330 122 149 0 0 0 0 0 0\n'
+        val_bbox_contents = '0--Parade/0_Parade_marchingband_1_2.jpg\n1\n501 160 285 443 0 0 0 0 0 0\n'
+        test_filelist_contents = '0--Parade/0_Parade_marchingband_1_3.jpg\n'
+        extracted_dir = os.path.join(root, 'wider_face_split')
+        os.mkdir(extracted_dir)
+
+        # bbox training file
+        bbox_file = os.path.join(extracted_dir, "wider_face_train_bbx_gt.txt")
+        with open(bbox_file, "w") as txt_file:
+            txt_file.write(train_bbox_contents)
+
+        # bbox validation file
+        bbox_file = os.path.join(extracted_dir, "wider_face_val_bbx_gt.txt")
+        with open(bbox_file, "w") as txt_file:
+            txt_file.write(val_bbox_contents)
+
+        # test filelist file
+        filelist_file = os.path.join(extracted_dir, "wider_face_test_filelist.txt")
+        with open(filelist_file, "w") as txt_file:
+            txt_file.write(test_filelist_contents)
+
+    with get_tmp_dir() as root:
+        root_base = os.path.join(root, "widerface")
+        os.mkdir(root_base)
+        _make_train_archive(root_base)
+        _make_val_archive(root_base)
+        _make_test_archive(root_base)
+        _make_annotations_archive(root_base)
+
+        yield root
+
+
 @contextlib.contextmanager
 def cityscapes_root():
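For reference, the `*_bbx_gt.txt` strings written above follow the WIDER FACE annotation layout that the new dataset class parses below: one line with the image path, one line with the number of boxes, then one line of ten integers per box (x, y, width, height, followed by the blur, expression, illumination, occlusion, pose, and invalid attributes, in the order the dataset class reads them). A minimal standalone reader for that layout, as a sketch; the `parse_bbx_gt` helper is hypothetical and not part of this commit:

def parse_bbx_gt(path):
    # Hypothetical illustration of the record layout used by the fake data above.
    records = []
    with open(path) as f:
        lines = (line.rstrip() for line in f)
        for img_path in lines:                    # first line of a record: image path
            num_boxes = int(next(lines))          # second line: number of face boxes
            boxes = [[int(v) for v in next(lines).split()]
                     for _ in range(num_boxes)]   # one line of 10 ints per box
            records.append((img_path, boxes))
    return records

# For the fake training file above, this returns:
# [('0--Parade/0_Parade_marchingband_1_1.jpg',
#   [[449, 330, 122, 149, 0, 0, 0, 0, 0, 0]])]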

test/test_datasets.py

Lines changed: 21 additions & 1 deletion
@@ -9,7 +9,7 @@
 import torchvision
 from common_utils import get_tmp_dir
 from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
-    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root
+    cityscapes_root, svhn_root, voc_root, ucf101_root, places365_root, widerface_root
 import xml.etree.ElementTree as ET
 from urllib.request import Request, urlopen
 import itertools
@@ -139,6 +139,26 @@ def test_imagenet(self, mock_verify):
             dataset = torchvision.datasets.ImageNet(root, split='val')
             self.generic_classification_dataset_test(dataset)

+    @mock.patch('torchvision.datasets.WIDERFace._check_integrity')
+    @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
+    def test_widerface(self, mock_check_integrity):
+        mock_check_integrity.return_value = True
+        with widerface_root() as root:
+            dataset = torchvision.datasets.WIDERFace(root, split='train')
+            self.assertEqual(len(dataset), 1)
+            img, target = dataset[0]
+            self.assertTrue(isinstance(img, PIL.Image.Image))
+
+            dataset = torchvision.datasets.WIDERFace(root, split='val')
+            self.assertEqual(len(dataset), 1)
+            img, target = dataset[0]
+            self.assertTrue(isinstance(img, PIL.Image.Image))
+
+            dataset = torchvision.datasets.WIDERFace(root, split='test')
+            self.assertEqual(len(dataset), 1)
+            img, target = dataset[0]
+            self.assertTrue(isinstance(img, PIL.Image.Image))
+
     @mock.patch('torchvision.datasets.cifar.check_integrity')
     @mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
     def test_cifar10(self, mock_ext_check, mock_int_check):
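Note that the test patches `WIDERFace._check_integrity` to return `True`, so the fake on-disk layout produced by `widerface_root()` is accepted without MD5 checks or downloads. Assuming pytest is used as the test runner, an invocation along these lines should select just this test:

pytest test/test_datasets.py -k test_widerface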

torchvision/datasets/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -16,6 +16,7 @@
 from .imagenet import ImageNet
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
+from .widerface import WIDERFace
 from .sbd import SBDataset
 from .vision import VisionDataset
 from .usps import USPS
@@ -31,5 +32,6 @@
            'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
            'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
            'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
-           'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
-           'USPS', 'Kinetics400', 'HMDB51', 'UCF101', 'Places365')
+           'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset',
+           'VisionDataset', 'USPS', 'Kinetics400', 'HMDB51', 'UCF101',
+           'Places365')
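With the import and `__all__` entry in place, the new class is reachable from the usual namespace:

from torchvision.datasets import WIDERFace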

torchvision/datasets/widerface.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
+from PIL import Image
+import os
+from os.path import abspath, expanduser
+import torch
+from typing import Any, Callable, List, Dict, Optional, Tuple, Union
+from .utils import check_integrity, download_file_from_google_drive, \
+    download_and_extract_archive, extract_archive, verify_str_arg
+from .vision import VisionDataset
+
+
+class WIDERFace(VisionDataset):
+    """`WIDERFace <http://shuoyang1213.me/WIDERFACE/>`_ Dataset.
+
+    Args:
+        root (string): Root directory where images and annotations are downloaded to.
+            Expects the following folder structure if download=False:
+                <root>
+                └── widerface
+                    ├── wider_face_split ('wider_face_split.zip' if compressed)
+                    ├── WIDER_train ('WIDER_train.zip' if compressed)
+                    ├── WIDER_val ('WIDER_val.zip' if compressed)
+                    └── WIDER_test ('WIDER_test.zip' if compressed)
+        split (string): The dataset split to use. One of {``train``, ``val``, ``test``}.
+            Defaults to ``train``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in the root directory. If the dataset is already downloaded, it is
+            not downloaded again.
+    """
+
+    BASE_FOLDER = "widerface"
+    FILE_LIST = [
+        # File ID                         MD5 Hash                            Filename
+        ("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
+        ("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
+        ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip")
+    ]
+    ANNOTATIONS_FILE = (
+        "http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip",
+        "0e3767bcf0e326556d407bf5bff5d27c",
+        "wider_face_split.zip"
+    )
+
+    def __init__(
+        self,
+        root: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        download: bool = False,
+    ) -> None:
+        super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
+                                        transform=transform,
+                                        target_transform=target_transform)
+        # check arguments
+        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError("Dataset not found or corrupted. " +
+                               "You can use download=True to download and prepare it")
+
+        self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
+        if self.split in ("train", "val"):
+            self.parse_train_val_annotations_file()
+        else:
+            self.parse_test_annotations_file()
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is a dict of annotations for all faces in the image.
+            target=None for the test split.
+        """
+
+        # stay consistent with other datasets and return a PIL Image
+        img = Image.open(self.img_info[index]["img_path"])
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        target = None if self.split == "test" else self.img_info[index]["annotations"]
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self) -> int:
+        return len(self.img_info)
+
+    def extra_repr(self) -> str:
+        lines = ["Split: {split}"]
+        return '\n'.join(lines).format(**self.__dict__)
+
+    def parse_train_val_annotations_file(self) -> None:
+        filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
+        filepath = os.path.join(self.root, "wider_face_split", filename)
+
+        with open(filepath, "r") as f:
+            lines = f.readlines()
+            # records cycle through three line types: file name, box count,
+            # then one annotation line per box
+            file_name_line, num_boxes_line, box_annotation_line = True, False, False
+            num_boxes, box_counter = 0, 0
+            labels = []
+            for line in lines:
+                line = line.rstrip()
+                if file_name_line:
+                    img_path = os.path.join(self.root, "WIDER_" + self.split, "images", line)
+                    img_path = abspath(expanduser(img_path))
+                    file_name_line = False
+                    num_boxes_line = True
+                elif num_boxes_line:
+                    num_boxes = int(line)
+                    num_boxes_line = False
+                    box_annotation_line = True
+                elif box_annotation_line:
+                    box_counter += 1
+                    line_split = line.split(" ")
+                    line_values = [int(x) for x in line_split]
+                    labels.append(line_values)
+                    if box_counter >= num_boxes:
+                        box_annotation_line = False
+                        file_name_line = True
+                        labels_tensor = torch.tensor(labels)
+                        self.img_info.append({
+                            "img_path": img_path,
+                            "annotations": {"bbox": labels_tensor[:, 0:4],  # x, y, width, height
+                                            "blur": labels_tensor[:, 4],
+                                            "expression": labels_tensor[:, 5],
+                                            "illumination": labels_tensor[:, 6],
+                                            "occlusion": labels_tensor[:, 7],
+                                            "pose": labels_tensor[:, 8],
+                                            "invalid": labels_tensor[:, 9]}
+                        })
+                        box_counter = 0
+                        labels.clear()
+                else:
+                    raise RuntimeError("Error parsing annotation file {}".format(filepath))
+
+    def parse_test_annotations_file(self) -> None:
+        filepath = os.path.join(self.root, "wider_face_split", "wider_face_test_filelist.txt")
+        filepath = abspath(expanduser(filepath))
+        with open(filepath, "r") as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.rstrip()
+                img_path = os.path.join(self.root, "WIDER_test", "images", line)
+                img_path = abspath(expanduser(img_path))
+                self.img_info.append({"img_path": img_path})
+
+    def _check_integrity(self) -> bool:
+        # Allow the original archives (zip) to be deleted; only the extracted
+        # directories are needed
+        all_files = self.FILE_LIST.copy()
+        all_files.append(self.ANNOTATIONS_FILE)
+        for (_, md5, filename) in all_files:
+            file, ext = os.path.splitext(filename)
+            extracted_dir = os.path.join(self.root, file)
+            if not os.path.exists(extracted_dir):
+                return False
+        return True
+
+    def download(self) -> None:
+        if self._check_integrity():
+            print('Files already downloaded and verified')
+            return
+
+        # download and extract image data
+        for (file_id, md5, filename) in self.FILE_LIST:
+            download_file_from_google_drive(file_id, self.root, filename, md5)
+            filepath = os.path.join(self.root, filename)
+            extract_archive(filepath)
+
+        # download and extract annotation files
+        download_and_extract_archive(url=self.ANNOTATIONS_FILE[0],
+                                     download_root=self.root,
+                                     md5=self.ANNOTATIONS_FILE[1])
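For context, using the new dataset mirrors the other torchvision datasets. A minimal sketch, assuming the archives are already extracted under `<root>/widerface` (otherwise pass `download=True` to fetch them via the Google Drive logic above); the `/data` root is a placeholder:

import torchvision

dataset = torchvision.datasets.WIDERFace("/data", split="train")
img, target = dataset[0]       # PIL image and a dict of annotation tensors
print(target["bbox"].shape)    # torch.Size([num_faces, 4]): x, y, width, height
# For split="test" the target is None and only image paths are loaded.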
