File tree 1 file changed +34
-0
lines changed
torchvision/prototype/datasets/utils 1 file changed +34
-0
lines changed Original file line number Diff line number Diff line change 20
20
extract_archive ,
21
21
_decompress ,
22
22
download_file_from_google_drive ,
23
+ _get_redirect_url ,
24
+ _get_google_drive_file_id ,
23
25
)
24
26
25
27
@@ -134,9 +136,41 @@ def __init__(
134
136
super ().__init__ (file_name = file_name or pathlib .Path (urlparse (url ).path ).name , ** kwargs )
135
137
self .url = url
136
138
self .mirrors = mirrors
139
+ self ._resolved = False
140
+
141
+ def resolve (self ) -> OnlineResource :
142
+ if self ._resolved :
143
+ return self
144
+
145
+ redirect_url = _get_redirect_url (self .url )
146
+ if redirect_url == self .url :
147
+ self ._resolved = True
148
+ return self
149
+
150
+ meta = {
151
+ attr .lstrip ("_" ): getattr (self , attr )
152
+ for attr in (
153
+ "file_name" ,
154
+ "sha256" ,
155
+ "_preprocess" ,
156
+ "_loader" ,
157
+ )
158
+ }
159
+
160
+ gdrive_id = _get_google_drive_file_id (redirect_url )
161
+ if gdrive_id :
162
+ return GDriveResource (gdrive_id , ** meta )
163
+
164
+ http_resource = HttpResource (redirect_url , ** meta )
165
+ http_resource ._resolved = True
166
+ return http_resource
137
167
138
168
def _download (self , root : pathlib .Path ) -> None :
169
+ if not self ._resolved :
170
+ return self .resolve ()._download (root )
171
+
139
172
for url in itertools .chain ((self .url ,), self .mirrors ):
173
+
140
174
try :
141
175
download_url (url , str (root ), filename = self .file_name , md5 = None )
142
176
# TODO: make this more precise
You can’t perform that action at this time.
0 commit comments