3
3
import os
4
4
import os .path
5
5
import numpy as np
6
+ from typing import Any , Callable , Optional , Tuple
6
7
from .utils import download_url , check_integrity , verify_str_arg
7
8
8
9
@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
39
40
'extra' : ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat" ,
40
41
"extra_32x32.mat" , "a93ce644f1a588dc4d68dda5feec44a7" ]}
41
42
42
- def __init__ (self , root , split = 'train' , transform = None , target_transform = None ,
43
- download = False ):
43
+ def __init__ (
44
+ self ,
45
+ root : str ,
46
+ split : str = "train" ,
47
+ transform : Optional [Callable ] = None ,
48
+ target_transform : Optional [Callable ] = None ,
49
+ download : bool = False ,
50
+ ) -> None :
44
51
super (SVHN , self ).__init__ (root , transform = transform ,
45
52
target_transform = target_transform )
46
53
self .split = verify_str_arg (split , "split" , tuple (self .split_list .keys ()))
@@ -75,7 +82,7 @@ def __init__(self, root, split='train', transform=None, target_transform=None,
75
82
np .place (self .labels , self .labels == 10 , 0 )
76
83
self .data = np .transpose (self .data , (3 , 2 , 0 , 1 ))
77
84
78
- def __getitem__ (self , index ) :
85
+ def __getitem__ (self , index : int ) -> Tuple [ Any , Any ] :
79
86
"""
80
87
Args:
81
88
index (int): Index
@@ -97,18 +104,18 @@ def __getitem__(self, index):
97
104
98
105
return img , target
99
106
100
- def __len__ (self ):
107
+ def __len__ (self ) -> int :
101
108
return len (self .data )
102
109
103
- def _check_integrity (self ):
110
+ def _check_integrity (self ) -> bool :
104
111
root = self .root
105
112
md5 = self .split_list [self .split ][2 ]
106
113
fpath = os .path .join (root , self .filename )
107
114
return check_integrity (fpath , md5 )
108
115
109
- def download (self ):
116
+ def download (self ) -> None :
110
117
md5 = self .split_list [self .split ][2 ]
111
118
download_url (self .url , self .root , self .filename , md5 )
112
119
113
- def extra_repr (self ):
120
+ def extra_repr (self ) -> str :
114
121
return "Split: {split}" .format (** self .__dict__ )
0 commit comments