Skip to content

Commit 1a6148d

Browse files
authored
add typehints for torchvision.datasets.svhn (#2539)
1 parent 7c1ed41 commit 1a6148d

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

torchvision/datasets/svhn.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import os.path
55
import numpy as np
6+
from typing import Any, Callable, Optional, Tuple
67
from .utils import download_url, check_integrity, verify_str_arg
78

89

@@ -39,8 +40,14 @@ class SVHN(VisionDataset):
3940
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
4041
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
4142

42-
def __init__(self, root, split='train', transform=None, target_transform=None,
43-
download=False):
43+
def __init__(
44+
self,
45+
root: str,
46+
split: str = "train",
47+
transform: Optional[Callable] = None,
48+
target_transform: Optional[Callable] = None,
49+
download: bool = False,
50+
) -> None:
4451
super(SVHN, self).__init__(root, transform=transform,
4552
target_transform=target_transform)
4653
self.split = verify_str_arg(split, "split", tuple(self.split_list.keys()))
@@ -75,7 +82,7 @@ def __init__(self, root, split='train', transform=None, target_transform=None,
7582
np.place(self.labels, self.labels == 10, 0)
7683
self.data = np.transpose(self.data, (3, 2, 0, 1))
7784

78-
def __getitem__(self, index):
85+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
7986
"""
8087
Args:
8188
index (int): Index
@@ -97,18 +104,18 @@ def __getitem__(self, index):
97104

98105
return img, target
99106

100-
def __len__(self):
107+
def __len__(self) -> int:
101108
return len(self.data)
102109

103-
def _check_integrity(self):
110+
def _check_integrity(self) -> bool:
104111
root = self.root
105112
md5 = self.split_list[self.split][2]
106113
fpath = os.path.join(root, self.filename)
107114
return check_integrity(fpath, md5)
108115

109-
def download(self):
116+
def download(self) -> None:
110117
md5 = self.split_list[self.split][2]
111118
download_url(self.url, self.root, self.filename, md5)
112119

113-
def extra_repr(self):
120+
def extra_repr(self) -> str:
114121
return "Split: {split}".format(**self.__dict__)

0 commit comments

Comments
 (0)