Skip to content

Commit d40ba15

Browse files
authored
Merge branch 'main' into models/convnext_graduation
2 parents 0948f50 + d5a22a8 commit d40ba15

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
extract_archive,
2121
_decompress,
2222
download_file_from_google_drive,
23+
_get_redirect_url,
24+
_get_google_drive_file_id,
2325
)
2426

2527

@@ -134,9 +136,41 @@ def __init__(
134136
super().__init__(file_name=file_name or pathlib.Path(urlparse(url).path).name, **kwargs)
135137
self.url = url
136138
self.mirrors = mirrors
139+
self._resolved = False
140+
141+
def resolve(self) -> OnlineResource:
142+
if self._resolved:
143+
return self
144+
145+
redirect_url = _get_redirect_url(self.url)
146+
if redirect_url == self.url:
147+
self._resolved = True
148+
return self
149+
150+
meta = {
151+
attr.lstrip("_"): getattr(self, attr)
152+
for attr in (
153+
"file_name",
154+
"sha256",
155+
"_preprocess",
156+
"_loader",
157+
)
158+
}
159+
160+
gdrive_id = _get_google_drive_file_id(redirect_url)
161+
if gdrive_id:
162+
return GDriveResource(gdrive_id, **meta)
163+
164+
http_resource = HttpResource(redirect_url, **meta)
165+
http_resource._resolved = True
166+
return http_resource
137167

138168
def _download(self, root: pathlib.Path) -> None:
169+
if not self._resolved:
170+
return self.resolve()._download(root)
171+
139172
for url in itertools.chain((self.url,), self.mirrors):
173+
140174
try:
141175
download_url(url, str(root), filename=self.file_name, md5=None)
142176
# TODO: make this more precise

0 commit comments

Comments
 (0)