1
1
#!/usr/bin/env python
2
2
3
3
import argparse
4
+ import base64
5
+ import dataclasses
6
+ import functools
4
7
import time
5
8
6
9
from os import path , makedirs
7
10
from datetime import datetime
8
11
from collections import defaultdict
9
- from typing import Iterator , List , Type , Dict , Set , TypeVar , Optional
12
+ from typing import Iterable , List , Type , Dict , Set , TypeVar , Optional
10
13
from re import sub , match , search
11
14
from packaging .version import parse
12
15
13
16
import boto3
14
17
15
18
16
19
# Module-level AWS handles: one boto3 S3 resource shared by all index
# operations, and the bucket every read/write targets.
S3 = boto3.resource('s3')
BUCKET = S3.Bucket('pytorch')

# Only package artifacts with these extensions are ever indexed; everything
# else under a prefix is ignored when building the HTML listings.
ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
107
109
108
110
S3IndexType = TypeVar ('S3IndexType' , bound = 'S3Index' )
109
111
112
@dataclasses.dataclass(frozen=True)
@functools.total_ordering
class S3Object:
    """An S3 key together with its (optional) SHA-256 checksum.

    Equality, ordering, and hashing are all based solely on ``key`` so that
    objects differing only in checksum metadata are interchangeable in sets
    and sorted listings.
    """
    key: str
    # Hex-encoded SHA-256 of the object body, or None when S3 reported no
    # ChecksumSHA256 for the key.
    checksum: Optional[str]

    def __str__(self) -> str:
        return self.key

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, S3Object):
            return NotImplemented
        return self.key == other.key

    def __lt__(self, other: "S3Object") -> bool:
        if not isinstance(other, S3Object):
            return NotImplemented
        return self.key < other.key

    def __hash__(self) -> int:
        # BUG FIX: defining __eq__ in the class body implicitly sets
        # __hash__ = None, so @dataclass(frozen=True, eq=True) synthesized a
        # replacement hash over (key, checksum) while __eq__ compares only
        # `key` — equal objects could hash differently, breaking set/dict
        # usage. Hash must agree with __eq__, i.e. depend on `key` alone.
        return hash(self.key)
110
129
def extract_package_build_time (full_package_name : str ) -> datetime :
111
130
result = search (PACKAGE_DATE_REGEX , full_package_name )
112
131
if result is not None :
@@ -124,7 +143,7 @@ def between_bad_dates(package_build_time: datetime):
124
143
125
144
126
145
class S3Index :
127
- def __init__ (self : S3IndexType , objects : List [str ], prefix : str ) -> None :
146
+ def __init__ (self : S3IndexType , objects : List [S3Object ], prefix : str ) -> None :
128
147
self .objects = objects
129
148
self .prefix = prefix .rstrip ("/" )
130
149
self .html_name = PREFIXES_WITH_HTML [self .prefix ]
@@ -134,7 +153,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
134
153
path .dirname (obj ) for obj in objects if path .dirname != prefix
135
154
}
136
155
137
- def nightly_packages_to_show (self : S3IndexType ) -> Set [str ]:
156
+ def nightly_packages_to_show (self : S3IndexType ) -> Set [S3Object ]:
138
157
"""Finding packages to show based on a threshold we specify
139
158
140
159
Basically takes our S3 packages, normalizes the version for easier
@@ -174,8 +193,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
174
193
if self .normalize_package_version (obj ) in to_hide
175
194
})
176
195
177
- def is_obj_at_root (self , obj :str ) -> bool :
178
- return path .dirname (obj ) == self .prefix
196
+ def is_obj_at_root (self , obj : S3Object ) -> bool :
197
+ return path .dirname (str ( obj ) ) == self .prefix
179
198
180
199
def _resolve_subdir (self , subdir : Optional [str ] = None ) -> str :
181
200
if not subdir :
@@ -187,7 +206,7 @@ def gen_file_list(
187
206
self ,
188
207
subdir : Optional [str ]= None ,
189
208
package_name : Optional [str ] = None
190
- ) -> Iterator [ str ]:
209
+ ) -> Iterable [ S3Object ]:
191
210
objects = (
192
211
self .nightly_packages_to_show () if self .prefix == 'whl/nightly'
193
212
else self .objects
@@ -197,23 +216,23 @@ def gen_file_list(
197
216
if package_name is not None :
198
217
if self .obj_to_package_name (obj ) != package_name :
199
218
continue
200
- if self .is_obj_at_root (obj ) or obj .startswith (subdir ):
219
+ if self .is_obj_at_root (obj ) or str ( obj ) .startswith (subdir ):
201
220
yield obj
202
221
203
222
def get_package_names (self , subdir : Optional [str ] = None ) -> List [str ]:
204
223
return sorted (set (self .obj_to_package_name (obj ) for obj in self .gen_file_list (subdir )))
205
224
206
- def normalize_package_version (self : S3IndexType , obj : str ) -> str :
225
+ def normalize_package_version (self : S3IndexType , obj : S3Object ) -> str :
207
226
# removes the GPU specifier from the package name as well as
208
227
# unnecessary things like the file extension, architecture name, etc.
209
228
return sub (
210
229
r"%2B.*" ,
211
230
"" ,
212
- "-" .join (path .basename (obj ).split ("-" )[:2 ])
231
+ "-" .join (path .basename (str ( obj ) ).split ("-" )[:2 ])
213
232
)
214
233
215
- def obj_to_package_name (self , obj : str ) -> str :
216
- return path .basename (obj ).split ('-' , 1 )[0 ]
234
+ def obj_to_package_name (self , obj : S3Object ) -> str :
235
+ return path .basename (str ( obj ) ).split ('-' , 1 )[0 ]
217
236
218
237
def to_legacy_html (
219
238
self ,
@@ -258,7 +277,8 @@ def to_simple_package_html(
258
277
out .append (' <body>' )
259
278
out .append (' <h1>Links for {}</h1>' .format (package_name .lower ().replace ("_" ,"-" )))
260
279
for obj in sorted (self .gen_file_list (subdir , package_name )):
261
- out .append (f' <a href="/{ obj } ">{ path .basename (obj ).replace ("%2B" ,"+" )} </a><br/>' )
280
+ maybe_fragment = f"#sha256={ obj .checksum } " if obj .checksum else ""
281
+ out .append (f' <a href="/{ obj } { maybe_fragment } ">{ path .basename (obj ).replace ("%2B" ,"+" )} </a><br/>' )
262
282
# Adding html footer
263
283
out .append (' </body>' )
264
284
out .append ('</html>' )
@@ -319,7 +339,6 @@ def upload_pep503_htmls(self) -> None:
319
339
Body = self .to_simple_package_html (subdir = subdir , package_name = pkg_name )
320
340
)
321
341
322
-
323
342
def save_legacy_html (self ) -> None :
324
343
for subdir in self .subdirs :
325
344
print (f"INFO Saving { subdir } /{ self .html_name } " )
@@ -351,10 +370,18 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
351
370
for pattern in ACCEPTED_SUBDIR_PATTERNS
352
371
]) and obj .key .endswith (ACCEPTED_FILE_EXTENSIONS )
353
372
if is_acceptable :
373
+ # Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
374
+ response = obj .meta .client .head_object (Bucket = BUCKET .name , Key = obj .key , ChecksumMode = "ENABLED" )
375
+ sha256 = (_b64 := response .get ("ChecksumSHA256" )) and base64 .b64decode (_b64 ).hex ()
354
376
sanitized_key = obj .key .replace ("+" , "%2B" )
355
- objects .append (sanitized_key )
377
+ s3_object = S3Object (
378
+ key = sanitized_key ,
379
+ checksum = sha256 ,
380
+ )
381
+ objects .append (s3_object )
356
382
return cls (objects , prefix )
357
383
384
+
358
385
def create_parser () -> argparse .ArgumentParser :
359
386
parser = argparse .ArgumentParser ("Manage S3 HTML indices for PyTorch" )
360
387
parser .add_argument (
@@ -366,6 +393,7 @@ def create_parser() -> argparse.ArgumentParser:
366
393
parser .add_argument ("--generate-pep503" , action = "store_true" )
367
394
return parser
368
395
396
+
369
397
def main ():
370
398
parser = create_parser ()
371
399
args = parser .parse_args ()
@@ -390,5 +418,6 @@ def main():
390
418
if args .generate_pep503 :
391
419
idx .upload_pep503_htmls ()
392
420
421
+
393
422
if __name__ == "__main__" :
394
423
main ()
0 commit comments