Skip to content

Commit 151e162

Browse files
authored
fix HttpResource.resolve() with preprocessing (#5669)
* fix HttpResource.resolve() with preprocess set * fix README * add safeguard for invalid str inputs
1 parent 647016b commit 151e162

File tree

8 files changed

+83
-24
lines changed

8 files changed

+83
-24
lines changed

test/test_prototype_datasets_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
from datasets_utils import make_fake_flo_file
77
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
8+
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
89
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
910

1011

@@ -45,3 +46,58 @@ def test_read_flo(tmpdir):
4546
expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))
4647

4748
torch.testing.assert_close(actual, expected)
49+
50+
51+
class TestHttpResource:
52+
def test_resolve_to_http(self, mocker):
53+
file_name = "data.tar"
54+
original_url = f"http://downloads.pytorch.org/{file_name}"
55+
56+
redirected_url = original_url.replace("http", "https")
57+
58+
sha256_sentinel = "sha256_sentinel"
59+
60+
def preprocess_sentinel(path):
61+
return path
62+
63+
original_resource = HttpResource(
64+
original_url,
65+
sha256=sha256_sentinel,
66+
preprocess=preprocess_sentinel,
67+
)
68+
69+
mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
70+
redirected_resource = original_resource.resolve()
71+
72+
assert isinstance(redirected_resource, HttpResource)
73+
assert redirected_resource.url == redirected_url
74+
assert redirected_resource.file_name == file_name
75+
assert redirected_resource.sha256 == sha256_sentinel
76+
assert redirected_resource._preprocess is preprocess_sentinel
77+
78+
def test_resolve_to_gdrive(self, mocker):
79+
file_name = "data.tar"
80+
original_url = f"http://downloads.pytorch.org/{file_name}"
81+
82+
id_sentinel = "id-sentinel"
83+
redirected_url = f"https://drive.google.com/file/d/{id_sentinel}/view"
84+
85+
sha256_sentinel = "sha256_sentinel"
86+
87+
def preprocess_sentinel(path):
88+
return path
89+
90+
original_resource = HttpResource(
91+
original_url,
92+
sha256=sha256_sentinel,
93+
preprocess=preprocess_sentinel,
94+
)
95+
96+
mocker.patch("torchvision.prototype.datasets.utils._resource._get_redirect_url", return_value=redirected_url)
97+
redirected_resource = original_resource.resolve()
98+
99+
assert isinstance(redirected_resource, GDriveResource)
100+
assert redirected_resource.id == id_sentinel
101+
assert redirected_resource.file_name == file_name
102+
assert redirected_resource.sha256 == sha256_sentinel
103+
assert redirected_resource._preprocess is preprocess_sentinel

torchvision/prototype/datasets/_builtin/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ To generate the `$NAME.categories` file, run `python -m torchvision.prototype.da
231231
### What if a resource file forms an I/O bottleneck?
232232

233233
In general, we are ok with small performance hits of iterating archives rather than their extracted content. However, if
234-
the performance hit becomes significant, the archives can still be decompressed or extracted. To do this, the
235-
`decompress: bool` and `extract: bool` flags can be used for every `OnlineResource` individually. For more complex
236-
cases, each resource also accepts a `preprocess` callable that gets passed a `pathlib.Path` of the raw file and should
237-
return `pathlib.Path` of the preprocessed file or folder.
234+
the performance hit becomes significant, the archives can still be preprocessed. `OnlineResource` accepts the
235+
`preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be
236+
preprocessed and the return value should be the path of the preprocessed file or folder to load. For convenience, `preprocess` also
237+
accepts `"decompress"` and `"extract"` to handle these common scenarios.

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
3232
images = HttpResource(
3333
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
3434
sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926",
35-
decompress=True,
35+
preprocess="decompress",
3636
)
3737
anns = HttpResource(
3838
"http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",

torchvision/prototype/datasets/_builtin/cub200.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,29 +51,29 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
5151
archive = HttpResource(
5252
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
5353
sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
54-
decompress=True,
54+
preprocess="decompress",
5555
)
5656
segmentations = HttpResource(
5757
"http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/segmentations.tgz",
5858
sha256="dc77f6cffea0cbe2e41d4201115c8f29a6320ecb04fffd2444f51b8066e4b84f",
59-
decompress=True,
59+
preprocess="decompress",
6060
)
6161
return [archive, segmentations]
6262
else: # config.year == "2010"
6363
split = HttpResource(
6464
"http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
6565
sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
66-
decompress=True,
66+
preprocess="decompress",
6767
)
6868
images = HttpResource(
6969
"http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz",
7070
sha256="2a6d2246bbb9778ca03aa94e2e683ccb4f8821a36b7f235c0822e659d60a803e",
71-
decompress=True,
71+
preprocess="decompress",
7272
)
7373
anns = HttpResource(
7474
"http://www.vision.caltech.edu/visipedia-data/CUB-200/annotations.tgz",
7575
sha256="c17b7841c21a66aa44ba8fe92369cc95dfc998946081828b1d7b8a4b716805c1",
76-
decompress=True,
76+
preprocess="decompress",
7777
)
7878
return [split, images, anns]
7979

torchvision/prototype/datasets/_builtin/dtd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
4949
archive = HttpResource(
5050
"https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
5151
sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
52-
decompress=True,
52+
preprocess="decompress",
5353
)
5454
return [archive]
5555

torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
4040
images = HttpResource(
4141
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
4242
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d",
43-
decompress=True,
43+
preprocess="decompress",
4444
)
4545
anns = HttpResource(
4646
"https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
4747
sha256="52425fb6de5c424942b7626b428656fcbd798db970a937df61750c0f1d358e91",
48-
decompress=True,
48+
preprocess="decompress",
4949
)
5050
return [images, anns]
5151

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _make_info(self) -> DatasetInfo:
9191

9292
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
9393
return [ # = [images resource, targets resource]
94-
GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, decompress=True)
94+
GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress")
9595
for file_name, gdrive_id, sha256 in self._RESOURCES[config.split]
9696
]
9797

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_get_redirect_url,
2424
_get_google_drive_file_id,
2525
)
26+
from typing_extensions import Literal
2627

2728

2829
class OnlineResource(abc.ABC):
@@ -31,19 +32,22 @@ def __init__(
3132
*,
3233
file_name: str,
3334
sha256: Optional[str] = None,
34-
decompress: bool = False,
35-
extract: bool = False,
35+
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
3636
) -> None:
3737
self.file_name = file_name
3838
self.sha256 = sha256
3939

40-
self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
41-
if extract:
42-
self._preprocess = self._extract
43-
elif decompress:
44-
self._preprocess = self._decompress
45-
else:
46-
self._preprocess = None
40+
if isinstance(preprocess, str):
41+
if preprocess == "decompress":
42+
preprocess = self._decompress
43+
elif preprocess == "extract":
44+
preprocess = self._extract
45+
else:
46+
raise ValueError(
47+
f"Only `'decompress'` or `'extract'` are valid if `preprocess` is passed as string,"
48+
f"but got {preprocess} instead."
49+
)
50+
self._preprocess = preprocess
4751

4852
@staticmethod
4953
def _extract(file: pathlib.Path) -> pathlib.Path:
@@ -163,7 +167,6 @@ def resolve(self) -> OnlineResource:
163167
"file_name",
164168
"sha256",
165169
"_preprocess",
166-
"_loader",
167170
)
168171
}
169172

0 commit comments

Comments
 (0)