From 13a4c6629ff2863974fb5a5b959986076055ac49 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 12:20:18 +0100 Subject: [PATCH 1/5] allow single extension as str in make_dataset --- test/test_datasets_utils.py | 27 +++++++++++++++++++++++++++ torchvision/datasets/folder.py | 12 ++++++++---- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index c3e63fb7f5e..78144b616ad 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,12 +1,14 @@ import contextlib import gzip import os +import pathlib import tarfile import zipfile import pytest import torchvision.datasets.utils as utils from torch._utils_internal import get_file_path_2 +from torchvision.datasets.folder import make_dataset from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS @@ -214,5 +216,30 @@ def test_verify_str_arg(self): pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") +class TestFolderUtils: + @pytest.mark.parametrize( + ("kwargs", "expected_error_msg"), + [ + (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), + (dict(extensions=".png"), r"classes b, c.*?[.]png"), + (dict(extensions=[".png", ".jpeg"]), "c.*?[.]png, [.]jpeg"), + ], + ) + def test_make_dataset_no_valid(self, tmpdir, kwargs, expected_error_msg): + tmpdir = pathlib.Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "a" / "a.png").touch() + + (tmpdir / "b").mkdir() + (tmpdir / "b" / "b.jpeg").touch() + + (tmpdir / "c").mkdir() + (tmpdir / "c" / "c.unknown").touch() + + with pytest.raises(FileNotFoundError, match=expected_error_msg): + make_dataset(str(tmpdir), **kwargs) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 6c436088de9..3c64fbdff86 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,13 +1,14 @@ import os import os.path from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from typing import Collection, Union from PIL import Image from .vision import VisionDataset -def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: +def has_file_allowed_extension(filename: str, extensions: Union[str, Collection[str]]) -> bool: """Checks if a file is an allowed extension. Args: @@ -17,7 +18,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo Returns: bool: True if the filename ends with one of given extensions """ - return filename.lower().endswith(extensions) + return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) def is_image_file(filename: str) -> bool: @@ -48,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: def make_dataset( directory: str, class_to_idx: Optional[Dict[str, int]] = None, - extensions: Optional[Tuple[str, ...]] = None, + extensions: Optional[Union[str, Collection[str]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: """Generates a list of samples of a form (path_to_sample, class). @@ -72,8 +73,11 @@ def make_dataset( if extensions is not None: + if isinstance(extensions, str): + extensions = (extensions,) + def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) + return has_file_allowed_extension(x, cast(Collection[str], extensions)) is_valid_file = cast(Callable[[str], bool], is_valid_file) From 81d5b763b604144c1177a73ffa88eff931ad71ec Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 12:37:04 +0100 Subject: [PATCH 2/5] remove test class --- test/test_datasets_utils.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 78144b616ad..a5d3d1b7d4a 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -216,29 +216,28 @@ def test_verify_str_arg(self): pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") -class TestFolderUtils: - @pytest.mark.parametrize( - ("kwargs", "expected_error_msg"), - [ - (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), - (dict(extensions=".png"), r"classes b, c.*?[.]png"), - (dict(extensions=[".png", ".jpeg"]), "c.*?[.]png, [.]jpeg"), - ], - ) - def test_make_dataset_no_valid(self, tmpdir, kwargs, expected_error_msg): - tmpdir = pathlib.Path(tmpdir) +@pytest.mark.parametrize( + ("kwargs", "expected_error_msg"), + [ + (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), + (dict(extensions=".png"), r"classes b, c.*?[.]png"), + (dict(extensions=[".png", ".jpeg"]), "c.*?[.]png, [.]jpeg"), + ], +) +def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg): + tmpdir = pathlib.Path(tmpdir) - (tmpdir / "a").mkdir() - (tmpdir / "a" / "a.png").touch() + (tmpdir / "a").mkdir() + (tmpdir / "a" / "a.png").touch() - (tmpdir / "b").mkdir() - (tmpdir / "b" / "b.jpeg").touch() + (tmpdir / "b").mkdir() + (tmpdir / "b" / "b.jpeg").touch() - (tmpdir / "c").mkdir() - (tmpdir / "c" / "c.unknown").touch() + (tmpdir / "c").mkdir() + (tmpdir / "c" / "c.unknown").touch() - with pytest.raises(FileNotFoundError, match=expected_error_msg): - make_dataset(str(tmpdir), **kwargs) + with pytest.raises(FileNotFoundError, match=expected_error_msg): + make_dataset(str(tmpdir), **kwargs) if __name__ == "__main__": From 831fcf13d789be8452e0dd3a3bee369e4d95aeda Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 14:08:15 +0100 Subject: [PATCH 3/5] remove regex --- test/test_datasets_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index a5d3d1b7d4a..82b2cca3a9d 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -2,6 +2,7 @@ import gzip import os import pathlib +import re import tarfile import zipfile @@ -11,7 +12,6 @@ from torchvision.datasets.folder import make_dataset from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS - TEST_FILE = get_file_path_2( os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" ) @@ -220,8 +220,8 @@ def test_verify_str_arg(self): ("kwargs", "expected_error_msg"), [ (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), - (dict(extensions=".png"), r"classes b, c.*?[.]png"), - (dict(extensions=[".png", ".jpeg"]), "c.*?[.]png, [.]jpeg"), + (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")), + (dict(extensions=[".png", ".jpeg"]), re.escape("c. Supported extensions are: .png, .jpeg")), ], ) def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg): From fd55a0681f7884148263c9186d12b624b113bfeb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 14:11:03 +0100 Subject: [PATCH 4/5] revert collection to tuple --- test/test_datasets_utils.py | 2 +- torchvision/datasets/folder.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 82b2cca3a9d..c61004a2d43 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -221,7 +221,7 @@ def test_verify_str_arg(self): [ (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")), - (dict(extensions=[".png", ".jpeg"]), re.escape("c. Supported extensions are: .png, .jpeg")), + (dict(extensions=(".png", ".jpeg")), re.escape("c. Supported extensions are: .png, .jpeg")), ], ) def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg): diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 3c64fbdff86..d5a7e88083b 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,14 +1,14 @@ import os import os.path from typing import Any, Callable, cast, Dict, List, Optional, Tuple -from typing import Collection, Union +from typing import Union from PIL import Image from .vision import VisionDataset -def has_file_allowed_extension(filename: str, extensions: Union[str, Collection[str]]) -> bool: +def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool: """Checks if a file is an allowed extension. Args: @@ -49,7 +49,7 @@ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: def make_dataset( directory: str, class_to_idx: Optional[Dict[str, int]] = None, - extensions: Optional[Union[str, Collection[str]]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, ) -> List[Tuple[str, int]]: """Generates a list of samples of a form (path_to_sample, class). @@ -73,11 +73,8 @@ def make_dataset( if extensions is not None: - if isinstance(extensions, str): - extensions = (extensions,) - def is_valid_file(x: str) -> bool: - return has_file_allowed_extension(x, cast(Collection[str], extensions)) + return has_file_allowed_extension(x, extensions) # type: ignore[arg-type] is_valid_file = cast(Callable[[str], bool], is_valid_file) @@ -102,7 +99,7 @@ def is_valid_file(x: str) -> bool: if empty_classes: msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " if extensions is not None: - msg += f"Supported extensions are: {', '.join(extensions)}" + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" raise FileNotFoundError(msg) return instances From 39cf48fdb97d6b61e29cd6d9915c667fddc51575 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 20 Jan 2022 14:12:46 +0100 Subject: [PATCH 5/5] cleanup --- test/test_datasets_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index c61004a2d43..ec68fd72a5b 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -221,7 +221,7 @@ def test_verify_str_arg(self): [ (dict(is_valid_file=lambda path: pathlib.Path(path).suffix in {".png", ".jpeg"}), "classes c"), (dict(extensions=".png"), re.escape("classes b, c. Supported extensions are: .png")), - (dict(extensions=(".png", ".jpeg")), re.escape("c. Supported extensions are: .png, .jpeg")), + (dict(extensions=(".png", ".jpeg")), re.escape("classes c. Supported extensions are: .png, .jpeg")), ], ) def test_make_dataset_no_valid_files(tmpdir, kwargs, expected_error_msg):