From eea25cc2e2ef38fe0084348c7577fd121263b4b8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 31 Dec 2021 10:40:02 +0100 Subject: [PATCH 1/3] resolve redirection for HTTP resources --- .../prototype/datasets/utils/_resource.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 94603bfc81e..35a0dd210f8 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -21,6 +21,8 @@ extract_archive, _decompress, download_file_from_google_drive, + _get_redirect_url, + _get_google_drive_file_id, ) @@ -141,8 +143,40 @@ def __init__( super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs) self.url = url self.mirrors = mirrors + self._resolved = False + + @property + def resolved(self) -> bool: + return self._resolved + + def resolve(self) -> Optional[OnlineResource]: + redirect_url = _get_redirect_url(self.url) + if redirect_url == self.url: + self._resolved = True + return self + + meta = { + attr.lstrip("_"): getattr(self, attr) + for attr in ( + "file_name", + "sha256", + "_preprocess", + "_loader", + ) + } + + gdrive_id = _get_google_drive_file_id(redirect_url) + if gdrive_id: + return GDriveResource(gdrive_id, **meta) + + http_resource = HttpResource(redirect_url, **meta) + http_resource._resolved = True + return http_resource def _download(self, root: pathlib.Path) -> None: + if not self.resolved: + return self.resolve()._download(root) + for url in itertools.chain((self.url,), self.mirrors or ()): try: download_url(url, str(root), filename=self.file_name, md5=None) From d47990f82f493ae1d18ae3bd83a08dbfaba67a5c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 3 Jan 2022 11:27:18 +0100 Subject: [PATCH 2/3] appease mypy --- torchvision/prototype/datasets/utils/_resource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 35a0dd210f8..a2643d07747 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -3,7 +3,7 @@ import itertools import pathlib import warnings -from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn +from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, cast from urllib.parse import urlparse from torchdata.datapipes.iter import ( @@ -175,7 +175,7 @@ def resolve(self) -> Optional[OnlineResource]: def _download(self, root: pathlib.Path) -> None: if not self.resolved: - return self.resolve()._download(root) + return cast(OnlineResource, self.resolve())._download(root) for url in itertools.chain((self.url,), self.mirrors or ()): try: From c372606b96f80fe57c356f2001110e48715af342 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 14 Jan 2022 13:41:45 +0100 Subject: [PATCH 3/3] address review --- torchvision/prototype/datasets/utils/_resource.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index a2643d07747..d7bf9fc4b18 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -3,7 +3,7 @@ import itertools import pathlib import warnings -from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, cast +from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn from urllib.parse import urlparse from torchdata.datapipes.iter import ( @@ -145,11 +145,10 @@ def __init__( self.mirrors = mirrors self._resolved = False - @property - def resolved(self) -> bool: - return self._resolved + def resolve(self) -> OnlineResource: + if self._resolved: + return self - def resolve(self) -> Optional[OnlineResource]: redirect_url = _get_redirect_url(self.url) if redirect_url == self.url: self._resolved = True @@ -174,8 +173,8 @@ def resolve(self) -> Optional[OnlineResource]: return http_resource def _download(self, root: pathlib.Path) -> None: - if not self.resolved: - return cast(OnlineResource, self.resolve())._download(root) + if not self._resolved: + return self.resolve()._download(root) for url in itertools.chain((self.url,), self.mirrors or ()): try: