diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py
index 47591e3db8c..0a56afc8382 100644
--- a/torchvision/datasets/phototour.py
+++ b/torchvision/datasets/phototour.py
@@ -1,6 +1,7 @@
 import os
 import numpy as np
 from PIL import Image
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 from .vision import VisionDataset
@@ -54,17 +55,19 @@ class PhotoTour(VisionDataset):
             'fdd9152f138ea5ef2091746689176414'
         ],
     }
-    mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
-            'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
-    std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
-           'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
+    means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
+             'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
+    stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
+            'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
     lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
             'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
     image_ext = 'bmp'
     info_file = 'info.txt'
     matches_files = 'm50_100000_100000_0.txt'

-    def __init__(self, root, name, train=True, transform=None, download=False):
+    def __init__(
+        self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
+    ) -> None:
         super(PhotoTour, self).__init__(root, transform=transform)
         self.name = name
         self.data_dir = os.path.join(self.root, name)
@@ -72,8 +75,8 @@ def __init__(self, root, name, train=True, transform=None, download=False):
         self.data_file = os.path.join(self.root, '{}.pt'.format(name))

         self.train = train
-        self.mean = self.mean[name]
-        self.std = self.std[name]
+        self.mean = self.means[name]
+        self.std = self.stds[name]

         if download:
             self.download()
@@ -85,7 +88,7 @@ def __init__(self, root, name, train=True, transform=None, download=False):
         # load the serialized data
         self.data, self.labels, self.matches = torch.load(self.data_file)

-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
         """
         Args:
             index (int): Index
@@ -105,18 +108,18 @@ def __getitem__(self, index):
             data2 = self.transform(data2)
         return data1, data2, m[2]

-    def __len__(self):
+    def __len__(self) -> int:
         if self.train:
             return self.lens[self.name]
         return len(self.matches)

-    def _check_datafile_exists(self):
+    def _check_datafile_exists(self) -> bool:
         return os.path.exists(self.data_file)

-    def _check_downloaded(self):
+    def _check_downloaded(self) -> bool:
         return os.path.exists(self.data_dir)

-    def download(self):
+    def download(self) -> None:
         if self._check_datafile_exists():
             print('# Found cached data {}'.format(self.data_file))
             return
@@ -150,20 +153,20 @@ def download(self):
         with open(self.data_file, 'wb') as f:
             torch.save(dataset, f)

-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return "Split: {}".format("Train" if self.train is True else "Test")


-def read_image_file(data_dir, image_ext, n):
+def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
     """Return a Tensor containing the patches
     """

-    def PIL2array(_img):
+    def PIL2array(_img: Image.Image) -> np.ndarray:
         """Convert PIL image type to numpy 2D array
         """
         return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

-    def find_files(_data_dir, _image_ext):
+    def find_files(_data_dir: str, _image_ext: str) -> List[str]:
         """Return a list with the file names of the images containing the patches
         """
         files = []
@@ -185,7 +188,7 @@ def find_files(_data_dir, _image_ext):
     return torch.ByteTensor(np.array(patches[:n]))


-def read_info_file(data_dir, info_file):
+def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
     """Return a Tensor containing the list of labels
        Read the file and keep only the ID of the 3D point.
     """
@@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file):
     return torch.LongTensor(labels)


-def read_matches_files(data_dir, matches_file):
+def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
     """Return a Tensor containing the ground truth matches
        Read the file and keep only 3D point ID.
        Matches are represented with a 1, non matches with a 0.