
Commit dbad8b7

Provide file hashes in the URLs to avoid unnecessary file downloads (bandwidth saver) (#1433)
Supply sha256 hash fragments in the index URLs using boto3 to avoid hundreds of extra gigabytes of downloads each day during pipenv and poetry resolution lock cycles. Fixes point 1 in pytorch/pytorch#76557. Fixes #1347.
1 parent 553b4df commit dbad8b7
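
The fragment matters because PEP 503 lets an index advertise a file's hash in the link itself, so a resolver that already holds a locked hash can skip the download entirely. A minimal client-side sketch of that check, with `needs_download`, `index_link`, and `locked_sha256` as hypothetical illustrative names (not part of this commit):

# A resolver that already holds a locked hash can compare it against the
# link's #sha256= fragment instead of downloading the wheel just to hash it.
from urllib.parse import urlparse

def needs_download(index_link: str, locked_sha256: str) -> bool:
    fragment = urlparse(index_link).fragment  # e.g. "sha256=3f0a..."
    if fragment.startswith("sha256="):
        return fragment[len("sha256="):] != locked_sha256
    return True  # no advertised hash, so the file must be fetched and hashed

# False: the advertised hash matches the lock file, no download needed.
print(needs_download(
    "/whl/torch-1.13.0%2Bcpu-cp310-cp310-linux_x86_64.whl#sha256=deadbeef",
    "deadbeef",
))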

File tree

s3_management/manage.py

1 file changed: 44 additions & 15 deletions
@@ -1,20 +1,22 @@
 #!/usr/bin/env python
 
 import argparse
+import base64
+import dataclasses
+import functools
 import time
 
 from os import path, makedirs
 from datetime import datetime
 from collections import defaultdict
-from typing import Iterator, List, Type, Dict, Set, TypeVar, Optional
+from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional
 from re import sub, match, search
 from packaging.version import parse
 
 import boto3
 
 
 S3 = boto3.resource('s3')
-CLIENT = boto3.client('s3')
 BUCKET = S3.Bucket('pytorch')
 
 ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
@@ -107,6 +109,23 @@
 
 S3IndexType = TypeVar('S3IndexType', bound='S3Index')
 
+
+@dataclasses.dataclass(frozen=True)
+@functools.total_ordering
+class S3Object:
+    key: str
+    checksum: str | None
+
+    def __str__(self):
+        return self.key
+
+    def __eq__(self, other):
+        return self.key == other.key
+
+    def __lt__(self, other):
+        return self.key < other.key
+
+
 def extract_package_build_time(full_package_name: str) -> datetime:
     result = search(PACKAGE_DATE_REGEX, full_package_name)
     if result is not None:
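
The new `S3Object` compares and sorts by `key` alone, so existing `sorted()` calls over file lists keep working, and `functools.total_ordering` derives `<=`, `>`, and `>=` from the hand-written `__eq__` and `__lt__`. A quick standalone illustration of the intended behavior (the wheel keys and digest are made-up values):

import dataclasses
import functools

@dataclasses.dataclass(frozen=True)
@functools.total_ordering
class S3Object:
    key: str
    checksum: str | None

    def __str__(self):
        return self.key

    def __eq__(self, other):
        return self.key == other.key

    def __lt__(self, other):
        return self.key < other.key

a = S3Object(key="whl/torch-1.13.0.whl", checksum=None)
b = S3Object(key="whl/torch-1.12.0.whl", checksum="3f0a...")
assert sorted([a, b])[0] is b            # orders lexicographically by key
assert a >= b                            # supplied by total_ordering
assert str(a) == "whl/torch-1.13.0.whl"  # str() keeps path-based helpers working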
@@ -124,7 +143,7 @@ def between_bad_dates(package_build_time: datetime):
 
 
 class S3Index:
-    def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
+    def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
         self.objects = objects
         self.prefix = prefix.rstrip("/")
         self.html_name = PREFIXES_WITH_HTML[self.prefix]
@@ -134,7 +153,7 @@ def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
             path.dirname(obj) for obj in objects if path.dirname != prefix
         }
 
-    def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
+    def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]:
         """Finding packages to show based on a threshold we specify
 
         Basically takes our S3 packages, normalizes the version for easier
@@ -174,8 +193,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]:
             if self.normalize_package_version(obj) in to_hide
         })
 
-    def is_obj_at_root(self, obj:str) -> bool:
-        return path.dirname(obj) == self.prefix
+    def is_obj_at_root(self, obj: S3Object) -> bool:
+        return path.dirname(str(obj)) == self.prefix
 
     def _resolve_subdir(self, subdir: Optional[str] = None) -> str:
         if not subdir:
@@ -187,7 +206,7 @@ def gen_file_list(
         self,
         subdir: Optional[str]=None,
         package_name: Optional[str] = None
-    ) -> Iterator[str]:
+    ) -> Iterable[S3Object]:
         objects = (
             self.nightly_packages_to_show() if self.prefix == 'whl/nightly'
             else self.objects
@@ -197,23 +216,23 @@
             if package_name is not None:
                 if self.obj_to_package_name(obj) != package_name:
                     continue
-            if self.is_obj_at_root(obj) or obj.startswith(subdir):
+            if self.is_obj_at_root(obj) or str(obj).startswith(subdir):
                 yield obj
 
     def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
         return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))
 
-    def normalize_package_version(self: S3IndexType, obj: str) -> str:
+    def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
         # removes the GPU specifier from the package name as well as
         # unnecessary things like the file extension, architecture name, etc.
         return sub(
             r"%2B.*",
             "",
-            "-".join(path.basename(obj).split("-")[:2])
+            "-".join(path.basename(str(obj)).split("-")[:2])
         )
 
-    def obj_to_package_name(self, obj: str) -> str:
-        return path.basename(obj).split('-', 1)[0]
+    def obj_to_package_name(self, obj: S3Object) -> str:
+        return path.basename(str(obj)).split('-', 1)[0]
 
     def to_legacy_html(
         self,
@@ -258,7 +277,8 @@ def to_simple_package_html(
         out.append('  <body>')
         out.append('    <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
         for obj in sorted(self.gen_file_list(subdir, package_name)):
-            out.append(f'    <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
+            maybe_fragment = f"#sha256={obj.checksum}" if obj.checksum else ""
+            out.append(f'    <a href="/{obj}{maybe_fragment}">{path.basename(obj).replace("%2B","+")}</a><br/>')
         # Adding html footer
         out.append('  </body>')
         out.append('</html>')
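
For objects with a known checksum, each generated PEP 503 link now ends in a `#sha256=<hex digest>` fragment. A hypothetical rendering of one entry, with a made-up key and truncated digest:

from os import path

# Made-up key and truncated digest, purely for illustration.
key = "whl/torch-1.13.0%2Bcpu-cp310-cp310-linux_x86_64.whl"
checksum = "3f0a9aeff9e4..."

maybe_fragment = f"#sha256={checksum}" if checksum else ""
print(f'    <a href="/{key}{maybe_fragment}">{path.basename(key).replace("%2B", "+")}</a><br/>')
# -> <a href="/whl/torch-1.13.0%2Bcpu-...whl#sha256=3f0a...">torch-1.13.0+cpu-...whl</a><br/>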
@@ -319,7 +339,6 @@ def upload_pep503_htmls(self) -> None:
                 Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
             )
 
-
     def save_legacy_html(self) -> None:
         for subdir in self.subdirs:
             print(f"INFO Saving {subdir}/{self.html_name}")
@@ -351,10 +370,18 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
                 for pattern in ACCEPTED_SUBDIR_PATTERNS
             ]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
             if is_acceptable:
+                # Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
+                response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
+                sha256 = (_b64 := response.get("ChecksumSHA256")) and base64.b64decode(_b64).hex()
                 sanitized_key = obj.key.replace("+", "%2B")
-                objects.append(sanitized_key)
+                s3_object = S3Object(
+                    key=sanitized_key,
+                    checksum=sha256,
+                )
+                objects.append(s3_object)
         return cls(objects, prefix)
 
+
 def create_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
     parser.add_argument(
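
S3 returns `ChecksumSHA256` as a base64-encoded digest, and only for objects that were uploaded with a SHA-256 checksum recorded, which is why `S3Object.checksum` stays optional and the walrus expression guards against a missing field. A standalone sketch of the lookup, with placeholder bucket and key values:

import base64
import boto3

client = boto3.client("s3")
# Placeholder bucket/key; ChecksumMode="ENABLED" asks S3 to include any
# stored checksums in the HeadObject response.
response = client.head_object(
    Bucket="pytorch",
    Key="whl/some-wheel.whl",
    ChecksumMode="ENABLED",
)
b64 = response.get("ChecksumSHA256")  # absent if no SHA-256 was recorded
sha256_hex = base64.b64decode(b64).hex() if b64 else None
print(sha256_hex)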
@@ -366,6 +393,7 @@ def create_parser() -> argparse.ArgumentParser:
    parser.add_argument("--generate-pep503", action="store_true")
    return parser
 
+
 def main():
     parser = create_parser()
     args = parser.parse_args()
@@ -390,5 +418,6 @@ def main():
     if args.generate_pep503:
         idx.upload_pep503_htmls()
 
+
 if __name__ == "__main__":
     main()
