Skip to content

Commit 62e3fbd

Browse files
authored
add typehints for torchvision.datasets.phototour (#2531)
1 parent 1a6148d commit 62e3fbd

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

torchvision/datasets/phototour.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import numpy as np
33
from PIL import Image
4+
from typing import Any, Callable, List, Optional, Tuple, Union
45

56
import torch
67
from .vision import VisionDataset
@@ -54,26 +55,28 @@ class PhotoTour(VisionDataset):
5455
'fdd9152f138ea5ef2091746689176414'
5556
],
5657
}
57-
mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
58-
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
59-
std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
60-
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
58+
means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
59+
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
60+
stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
61+
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
6162
lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
6263
'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
6364
image_ext = 'bmp'
6465
info_file = 'info.txt'
6566
matches_files = 'm50_100000_100000_0.txt'
6667

67-
def __init__(self, root, name, train=True, transform=None, download=False):
68+
def __init__(
69+
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
70+
) -> None:
6871
super(PhotoTour, self).__init__(root, transform=transform)
6972
self.name = name
7073
self.data_dir = os.path.join(self.root, name)
7174
self.data_down = os.path.join(self.root, '{}.zip'.format(name))
7275
self.data_file = os.path.join(self.root, '{}.pt'.format(name))
7376

7477
self.train = train
75-
self.mean = self.mean[name]
76-
self.std = self.std[name]
78+
self.mean = self.means[name]
79+
self.std = self.stds[name]
7780

7881
if download:
7982
self.download()
@@ -85,7 +88,7 @@ def __init__(self, root, name, train=True, transform=None, download=False):
8588
# load the serialized data
8689
self.data, self.labels, self.matches = torch.load(self.data_file)
8790

88-
def __getitem__(self, index):
91+
def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
8992
"""
9093
Args:
9194
index (int): Index
@@ -105,18 +108,18 @@ def __getitem__(self, index):
105108
data2 = self.transform(data2)
106109
return data1, data2, m[2]
107110

108-
def __len__(self):
111+
def __len__(self) -> int:
109112
if self.train:
110113
return self.lens[self.name]
111114
return len(self.matches)
112115

113-
def _check_datafile_exists(self):
116+
def _check_datafile_exists(self) -> bool:
114117
return os.path.exists(self.data_file)
115118

116-
def _check_downloaded(self):
119+
def _check_downloaded(self) -> bool:
117120
return os.path.exists(self.data_dir)
118121

119-
def download(self):
122+
def download(self) -> None:
120123
if self._check_datafile_exists():
121124
print('# Found cached data {}'.format(self.data_file))
122125
return
@@ -150,20 +153,20 @@ def download(self):
150153
with open(self.data_file, 'wb') as f:
151154
torch.save(dataset, f)
152155

153-
def extra_repr(self):
156+
def extra_repr(self) -> str:
154157
return "Split: {}".format("Train" if self.train is True else "Test")
155158

156159

157-
def read_image_file(data_dir, image_ext, n):
160+
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
158161
"""Return a Tensor containing the patches
159162
"""
160163

161-
def PIL2array(_img):
164+
def PIL2array(_img: Image.Image) -> np.ndarray:
162165
"""Convert PIL image type to numpy 2D array
163166
"""
164167
return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)
165168

166-
def find_files(_data_dir, _image_ext):
169+
def find_files(_data_dir: str, _image_ext: str) -> List[str]:
167170
"""Return a list with the file names of the images containing the patches
168171
"""
169172
files = []
@@ -185,7 +188,7 @@ def find_files(_data_dir, _image_ext):
185188
return torch.ByteTensor(np.array(patches[:n]))
186189

187190

188-
def read_info_file(data_dir, info_file):
191+
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
189192
"""Return a Tensor containing the list of labels
190193
Read the file and keep only the ID of the 3D point.
191194
"""
@@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file):
195198
return torch.LongTensor(labels)
196199

197200

198-
def read_matches_files(data_dir, matches_file):
201+
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
199202
"""Return a Tensor containing the ground truth matches
200203
Read the file and keep only 3D point ID.
201204
Matches are represented with a 1, non matches with a 0.

0 commit comments

Comments
 (0)