Skip to content

Commit 7c1ed41

Browse files
authored
add typehints for torchvision.datasets.voc (#2537)
1 parent 0acbf66 commit 7c1ed41

File tree

1 file changed

+29
-24
lines changed

1 file changed

+29
-24
lines changed

torchvision/datasets/voc.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .vision import VisionDataset
55
import xml.etree.ElementTree as ET
66
from PIL import Image
7+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
78
from .utils import download_url, check_integrity, verify_str_arg
89

910
DATASET_YEAR_DICT = {
@@ -70,14 +71,16 @@ class VOCSegmentation(VisionDataset):
7071
and returns a transformed version.
7172
"""
7273

73-
def __init__(self,
74-
root,
75-
year='2012',
76-
image_set='train',
77-
download=False,
78-
transform=None,
79-
target_transform=None,
80-
transforms=None):
74+
def __init__(
75+
self,
76+
root: str,
77+
year: str = "2012",
78+
image_set: str = "train",
79+
download: bool = False,
80+
transform: Optional[Callable] = None,
81+
target_transform: Optional[Callable] = None,
82+
transforms: Optional[Callable] = None,
83+
):
8184
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
8285
self.year = year
8386
if year == "2007" and image_set == "test":
@@ -112,7 +115,7 @@ def __init__(self,
112115
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
113116
assert (len(self.images) == len(self.masks))
114117

115-
def __getitem__(self, index):
118+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
116119
"""
117120
Args:
118121
index (int): Index
@@ -128,7 +131,7 @@ def __getitem__(self, index):
128131

129132
return img, target
130133

131-
def __len__(self):
134+
def __len__(self) -> int:
132135
return len(self.images)
133136

134137

@@ -151,14 +154,16 @@ class VOCDetection(VisionDataset):
151154
and returns a transformed version.
152155
"""
153156

154-
def __init__(self,
155-
root,
156-
year='2012',
157-
image_set='train',
158-
download=False,
159-
transform=None,
160-
target_transform=None,
161-
transforms=None):
157+
def __init__(
158+
self,
159+
root: str,
160+
year: str = "2012",
161+
image_set: str = "train",
162+
download: bool = False,
163+
transform: Optional[Callable] = None,
164+
target_transform: Optional[Callable] = None,
165+
transforms: Optional[Callable] = None,
166+
):
162167
super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
163168
self.year = year
164169
if year == "2007" and image_set == "test":
@@ -194,7 +199,7 @@ def __init__(self,
194199
self.annotations = [os.path.join(annotation_dir, x + ".xml") for x in file_names]
195200
assert (len(self.images) == len(self.annotations))
196201

197-
def __getitem__(self, index):
202+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
198203
"""
199204
Args:
200205
index (int): Index
@@ -211,14 +216,14 @@ def __getitem__(self, index):
211216

212217
return img, target
213218

214-
def __len__(self):
219+
def __len__(self) -> int:
215220
return len(self.images)
216221

217-
def parse_voc_xml(self, node):
218-
voc_dict = {}
222+
def parse_voc_xml(self, node: ET.Element) -> Dict[str, Any]:
223+
voc_dict: Dict[str, Any] = {}
219224
children = list(node)
220225
if children:
221-
def_dic = collections.defaultdict(list)
226+
def_dic: Dict[str, Any] = collections.defaultdict(list)
222227
for dc in map(self.parse_voc_xml, children):
223228
for ind, v in dc.items():
224229
def_dic[ind].append(v)
@@ -236,7 +241,7 @@ def parse_voc_xml(self, node):
236241
return voc_dict
237242

238243

239-
def download_extract(url, root, filename, md5):
244+
def download_extract(url: str, root: str, filename: str, md5: str) -> None:
240245
download_url(url, root, filename, md5)
241246
with tarfile.open(os.path.join(root, filename), "r") as tar:
242247
tar.extractall(path=root)

0 commit comments

Comments
 (0)