diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index c7fde65468a..294c0c9099b 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -20,6 +20,8 @@
     extract_archive,
     _decompress,
     download_file_from_google_drive,
+    _get_redirect_url,
+    _get_google_drive_file_id,
 )
 
 
@@ -134,9 +136,41 @@ def __init__(
         super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
         self.url = url
         self.mirrors = mirrors
+        self._resolved = False  # set to True by resolve() once self.url needs no further lookup
+
+    def resolve(self) -> OnlineResource:  # return the resource for the redirect target of self.url (self if unchanged)
+        if self._resolved:  # already resolved: skip the lookup entirely
+            return self
+
+        redirect_url = _get_redirect_url(self.url)  # presumably follows HTTP redirects; helper not shown
+        if redirect_url == self.url:  # lookup gave back the same URL: remember that and keep this resource
+            self._resolved = True
+            return self
+
+        meta = {  # constructor kwargs copied from this resource, keyed by attribute name minus leading "_"
+            attr.lstrip("_"): getattr(self, attr)
+            for attr in (
+                "file_name",
+                "sha256",
+                "_preprocess",
+                "_loader",
+            )
+        }
+
+        gdrive_id = _get_google_drive_file_id(redirect_url)  # falsy result means: treat target as plain HTTP
+        if gdrive_id:
+            return GDriveResource(gdrive_id, **meta)
+
+        http_resource = HttpResource(redirect_url, **meta)  # NOTE(review): self.mirrors is not forwarded; confirm
+        http_resource._resolved = True  # redirect_url came from the lookup above, so do not resolve it again
+        return http_resource
 
     def _download(self, root: pathlib.Path) -> None:
+        if not self._resolved:  # resolve first, then delegate the download to the returned resource
+            return self.resolve()._download(root)
+
         for url in itertools.chain((self.url,), self.mirrors):
+
             try:
                 download_url(url, str(root), filename=self.file_name, md5=None)
             # TODO: make this more precise