diff --git a/torchvision/datasets/lfw.py b/torchvision/datasets/lfw.py index a25765d5725..7a5aa45aa4d 100644 --- a/torchvision/datasets/lfw.py +++ b/torchvision/datasets/lfw.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image @@ -38,7 +38,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - ): + ) -> None: super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform) self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys()) @@ -62,7 +62,7 @@ def _loader(self, path: str) -> Image.Image: img = Image.open(f) return img.convert("RGB") - def _check_integrity(self): + def _check_integrity(self) -> bool: st1 = check_integrity(os.path.join(self.root, self.filename), self.md5) st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file]) if not st1 or not st2: @@ -71,7 +71,7 @@ def _check_integrity(self): return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names]) return True - def download(self): + def download(self) -> None: if self._check_integrity(): print("Files already downloaded and verified") return @@ -81,13 +81,13 @@ def download(self): if self.view == "people": download_url(f"{self.download_url_prefix}{self.names}", self.root) - def _get_path(self, identity, no): + def _get_path(self, identity: str, no: Union[int, str]) -> str: return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg") def extra_repr(self) -> str: return f"Alignment: {self.image_set}\nSplit: {self.split}" - def __len__(self): + def __len__(self) -> int: return len(self.data) @@ -119,13 +119,13 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - ): + ) -> None: super().__init__(root, split, image_set, "people", transform, target_transform, download) self.class_to_idx = self._get_classes() self.data, self.targets = self._get_people() - def _get_people(self): + def _get_people(self) -> Tuple[List[str], List[int]]: data, targets = [], [] with open(os.path.join(self.root, self.labels_file)) as f: lines = f.readlines() @@ -143,7 +143,7 @@ def _get_people(self): return data, targets - def _get_classes(self): + def _get_classes(self) -> Dict[str, int]: with open(os.path.join(self.root, self.names)) as f: lines = f.readlines() names = [line.strip().split()[0] for line in lines] @@ -201,12 +201,12 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, - ): + ) -> None: super().__init__(root, split, image_set, "pairs", transform, target_transform, download) self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir) - def _get_pairs(self, images_dir): + def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]: pair_names, data, targets = [], [], [] with open(os.path.join(self.root, self.labels_file)) as f: lines = f.readlines()