4
4
from .vision import VisionDataset
5
5
import xml .etree .ElementTree as ET
6
6
from PIL import Image
7
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
7
8
from .utils import download_url , check_integrity , verify_str_arg
8
9
9
10
DATASET_YEAR_DICT = {
@@ -70,14 +71,16 @@ class VOCSegmentation(VisionDataset):
70
71
and returns a transformed version.
71
72
"""
72
73
73
- def __init__ (self ,
74
- root ,
75
- year = '2012' ,
76
- image_set = 'train' ,
77
- download = False ,
78
- transform = None ,
79
- target_transform = None ,
80
- transforms = None ):
74
+ def __init__ (
75
+ self ,
76
+ root : str ,
77
+ year : str = "2012" ,
78
+ image_set : str = "train" ,
79
+ download : bool = False ,
80
+ transform : Optional [Callable ] = None ,
81
+ target_transform : Optional [Callable ] = None ,
82
+ transforms : Optional [Callable ] = None ,
83
+ ):
81
84
super (VOCSegmentation , self ).__init__ (root , transforms , transform , target_transform )
82
85
self .year = year
83
86
if year == "2007" and image_set == "test" :
@@ -112,7 +115,7 @@ def __init__(self,
112
115
self .masks = [os .path .join (mask_dir , x + ".png" ) for x in file_names ]
113
116
assert (len (self .images ) == len (self .masks ))
114
117
115
- def __getitem__ (self , index ) :
118
+ def __getitem__ (self , index : int ) -> Tuple [ Any , Any ] :
116
119
"""
117
120
Args:
118
121
index (int): Index
@@ -128,7 +131,7 @@ def __getitem__(self, index):
128
131
129
132
return img , target
130
133
131
- def __len__ (self ):
134
+ def __len__ (self ) -> int :
132
135
return len (self .images )
133
136
134
137
@@ -151,14 +154,16 @@ class VOCDetection(VisionDataset):
151
154
and returns a transformed version.
152
155
"""
153
156
154
- def __init__ (self ,
155
- root ,
156
- year = '2012' ,
157
- image_set = 'train' ,
158
- download = False ,
159
- transform = None ,
160
- target_transform = None ,
161
- transforms = None ):
157
+ def __init__ (
158
+ self ,
159
+ root : str ,
160
+ year : str = "2012" ,
161
+ image_set : str = "train" ,
162
+ download : bool = False ,
163
+ transform : Optional [Callable ] = None ,
164
+ target_transform : Optional [Callable ] = None ,
165
+ transforms : Optional [Callable ] = None ,
166
+ ):
162
167
super (VOCDetection , self ).__init__ (root , transforms , transform , target_transform )
163
168
self .year = year
164
169
if year == "2007" and image_set == "test" :
@@ -194,7 +199,7 @@ def __init__(self,
194
199
self .annotations = [os .path .join (annotation_dir , x + ".xml" ) for x in file_names ]
195
200
assert (len (self .images ) == len (self .annotations ))
196
201
197
- def __getitem__ (self , index ) :
202
+ def __getitem__ (self , index : int ) -> Tuple [ Any , Any ] :
198
203
"""
199
204
Args:
200
205
index (int): Index
@@ -211,14 +216,14 @@ def __getitem__(self, index):
211
216
212
217
return img , target
213
218
214
- def __len__ (self ):
219
+ def __len__ (self ) -> int :
215
220
return len (self .images )
216
221
217
- def parse_voc_xml (self , node ) :
218
- voc_dict = {}
222
+ def parse_voc_xml (self , node : ET . Element ) -> Dict [ str , Any ] :
223
+ voc_dict : Dict [ str , Any ] = {}
219
224
children = list (node )
220
225
if children :
221
- def_dic = collections .defaultdict (list )
226
+ def_dic : Dict [ str , Any ] = collections .defaultdict (list )
222
227
for dc in map (self .parse_voc_xml , children ):
223
228
for ind , v in dc .items ():
224
229
def_dic [ind ].append (v )
@@ -236,7 +241,7 @@ def parse_voc_xml(self, node):
236
241
return voc_dict
237
242
238
243
239
- def download_extract (url , root , filename , md5 ) :
244
+ def download_extract (url : str , root : str , filename : str , md5 : str ) -> None :
240
245
download_url (url , root , filename , md5 )
241
246
with tarfile .open (os .path .join (root , filename ), "r" ) as tar :
242
247
tar .extractall (path = root )
0 commit comments