Skip to content

add typehints for torchvision.datasets.phototour #2531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 3, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import numpy as np
from PIL import Image
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from .vision import VisionDataset
Expand Down Expand Up @@ -54,26 +55,28 @@ class PhotoTour(VisionDataset):
'fdd9152f138ea5ef2091746689176414'
],
}
mean = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
std = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is in theory a BC-breaking change.

I think it would be better to avoid renaming this in this PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it BC breaking? I mean after the assignment (L75-76) the dicts are no longer accessible. This only matters if one uses the dicts from the class directly PhotoTour.mean or PhotoTour.std.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I just saw this, yeah, this was most probably an oversight in the original implementation, good catch!

'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437}
stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019,
'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019}
lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092,
'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295}
image_ext = 'bmp'
info_file = 'info.txt'
matches_files = 'm50_100000_100000_0.txt'

def __init__(self, root, name, train=True, transform=None, download=False):
def __init__(
self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False
) -> None:
super(PhotoTour, self).__init__(root, transform=transform)
self.name = name
self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name))
self.data_file = os.path.join(self.root, '{}.pt'.format(name))

self.train = train
self.mean = self.mean[name]
self.std = self.std[name]
self.mean = self.means[name]
self.std = self.stds[name]

if download:
self.download()
Expand All @@ -85,7 +88,7 @@ def __init__(self, root, name, train=True, transform=None, download=False):
# load the serialized data
self.data, self.labels, self.matches = torch.load(self.data_file)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
"""
Args:
index (int): Index
Expand All @@ -105,18 +108,18 @@ def __getitem__(self, index):
data2 = self.transform(data2)
return data1, data2, m[2]

def __len__(self):
def __len__(self) -> int:
if self.train:
return self.lens[self.name]
return len(self.matches)

def _check_datafile_exists(self):
def _check_datafile_exists(self) -> bool:
return os.path.exists(self.data_file)

def _check_downloaded(self):
def _check_downloaded(self) -> bool:
return os.path.exists(self.data_dir)

def download(self):
def download(self) -> None:
if self._check_datafile_exists():
print('# Found cached data {}'.format(self.data_file))
return
Expand Down Expand Up @@ -150,20 +153,20 @@ def download(self):
with open(self.data_file, 'wb') as f:
torch.save(dataset, f)

def extra_repr(self):
def extra_repr(self) -> str:
return "Split: {}".format("Train" if self.train is True else "Test")


def read_image_file(data_dir, image_ext, n):
def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor:
"""Return a Tensor containing the patches
"""

def PIL2array(_img):
def PIL2array(_img: Image.Image) -> np.ndarray:
"""Convert PIL image type to numpy 2D array
"""
return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64)

def find_files(_data_dir, _image_ext):
def find_files(_data_dir: str, _image_ext: str) -> List[str]:
"""Return a list with the file names of the images containing the patches
"""
files = []
Expand All @@ -185,7 +188,7 @@ def find_files(_data_dir, _image_ext):
return torch.ByteTensor(np.array(patches[:n]))


def read_info_file(data_dir, info_file):
def read_info_file(data_dir: str, info_file: str) -> torch.Tensor:
"""Return a Tensor containing the list of labels
Read the file and keep only the ID of the 3D point.
"""
Expand All @@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file):
return torch.LongTensor(labels)


def read_matches_files(data_dir, matches_file):
def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor:
"""Return a Tensor containing the ground truth matches
Read the file and keep only 3D point ID.
Matches are represented with a 1, non matches with a 0.
Expand Down