diff --git a/.github/workflows/update-s3-html.yml b/.github/workflows/update-s3-html.yml
new file mode 100644
index 0000000000..1ba8ddc6ae
--- /dev/null
+++ b/.github/workflows/update-s3-html.yml
@@ -0,0 +1,45 @@
name: Update S3 HTML indices for download.pytorch.org

on:
  schedule:
    # Update the indices every 30 minutes
    - cron: "*/30 * * * *"
  workflow_dispatch:

permissions:
  id-token: write
  contents: read

jobs:
  update:
    runs-on: ubuntu-22.04
    environment: pytorchbot-env
    strategy:
      matrix:
        prefix: ["whl", "whl/test", "whl/nightly", "whl/lts/1.8"]
      # Keep the other prefixes running even if one of them fails.
      # (canonical lowercase boolean; `False` is a YAML 1.1-ism)
      fail-fast: false
    container:
      image: continuumio/miniconda3:4.12.0
    steps:
      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@v3
        with:
          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_update
          aws-region: us-east-1
      - name: Checkout repository test-infra
        uses: actions/checkout@v3
        with:
          repository: pytorch/test-infra
          ref: ${{ github.ref }}
      - name: Update s3 html index
        run: |
          set -ex

          # Create Conda Environment
          conda create --quiet -y --prefix run_env python="3.8"
          # `conda activate` needs the shell hook loaded first; in a fresh
          # non-interactive CI shell a bare `conda activate` errors out.
          eval "$(conda shell.bash hook)"
          conda activate ./run_env

          # Install requirements
          pip install -r s3_management/requirements.txt
          python s3_management/manage.py --generate-pep503 ${{ matrix.prefix }}
diff --git a/s3_management/README.md b/s3_management/README.md
new file mode 100644
index 0000000000..e2aab2661c
--- /dev/null
+++ b/s3_management/README.md
@@ -0,0 +1,3 @@
+# s3_management
+
+This directory houses scripts to maintain the s3 HTML indices for https://download.pytorch.org/whl
diff --git a/s3_management/backup_conda.py b/s3_management/backup_conda.py
new file mode 100644
index 0000000000..7dafa32b46
--- /dev/null
+++ b/s3_management/backup_conda.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Downloads domain pytorch and library packages from channel
+# And backs them up to S3
+# Do not use unless you know what you are doing
+# Usage: python backup_conda.py --version 1.6.0
+
import argparse
import hashlib
import os
import urllib
import urllib.request  # bare `import urllib` does not bind the `request` submodule
from typing import List, Optional

import boto3
import conda.api
+
+S3 = boto3.resource('s3')
+BUCKET = S3.Bucket('pytorch-backup')
+_known_subdirs = ["linux-64", "osx-64", "osx-arm64", "win-64"]
+
+
def compute_md5(path: str) -> str:
    """Return the hex MD5 digest of the file at ``path``."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        digest.update(fh.read())
    return digest.hexdigest()
+
+
def download_conda_package(package: str, version: Optional[str] = None,
                           depends: Optional[str] = None,
                           channel: Optional[str] = None) -> List[str]:
    """Download matching conda packages into ``<subdir>/<filename>``.

    Args:
        package: conda package name to query for.
        version: if given, keep only this exact version.
        depends: if given, keep only packages listing this dependency spec.
        channel: conda channel to query (default channels when None).

    Returns:
        Local paths of every file whose md5 matched the repo metadata.
    """
    packages = conda.api.SubdirData.query_all(
        package,
        channels=[channel] if channel is not None else None,
        subdirs=_known_subdirs)
    rc = []

    for pkg in packages:
        if version is not None and pkg.version != version:
            continue
        if depends is not None and depends not in pkg.depends:
            continue

        print(f"Downloading {pkg.url}...")
        os.makedirs(pkg.subdir, exist_ok=True)
        fname = f"{pkg.subdir}/{pkg.fn}"
        if not os.path.exists(fname):
            with open(fname, "wb") as f, urllib.request.urlopen(pkg.url) as resp:
                f.write(resp.read())
        # Checksum every file, including ones left over from a previous run
        # (the original only verified freshly downloaded files).
        if compute_md5(fname) != pkg.md5:
            print(f"md5 of {fname} is {compute_md5(fname)} does not match {pkg.md5}")
            # Drop the corrupt file so a rerun re-downloads it instead of
            # skipping the bad copy forever.
            os.remove(fname)
            continue
        rc.append(fname)

    return rc
+
def upload_to_s3(prefix: str, fnames: List[str]) -> None:
    """Upload each local file to the backup bucket under ``prefix``."""
    for local_path in fnames:
        BUCKET.upload_file(local_path, f"{prefix}/{local_path}")
        print(local_path)
+
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--version",
        type=str,
        required=True,
        help="PyTorch Version to backup",
    )
    options = parser.parse_args()
    prefix = f"v{options.version}/conda"
    # Back up the core pytorch package first, then each domain library
    # built against exactly this pytorch version.
    upload_to_s3(prefix, download_conda_package(
        "pytorch", channel="pytorch", version=options.version))

    for libname in ("torchvision", "torchaudio", "torchtext"):
        print(f"processing {libname}")
        upload_to_s3(prefix, download_conda_package(
            libname, channel="pytorch", depends=f"pytorch {options.version}"))
diff --git a/s3_management/manage.py b/s3_management/manage.py
new file mode 100644
index 0000000000..47c151f087
--- /dev/null
+++ b/s3_management/manage.py
@@ -0,0 +1,508 @@
+#!/usr/bin/env python
+
+import argparse
+import base64
+import concurrent.futures
+import dataclasses
+import functools
+import time
+
+from os import path, makedirs
+from datetime import datetime
+from collections import defaultdict
+from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional
+from re import sub, match, search
+from packaging.version import parse as _parse_version, Version, InvalidVersion
+
+import boto3
+
+
+S3 = boto3.resource('s3')
+CLIENT = boto3.client('s3')
+BUCKET = S3.Bucket('pytorch')
+
+ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
+ACCEPTED_SUBDIR_PATTERNS = [
+ r"cu[0-9]+", # for cuda
+ r"rocm[0-9]+\.[0-9]+", # for rocm
+ "cpu",
+]
+PREFIXES_WITH_HTML = {
+ "whl": "torch_stable.html",
+ "whl/lts/1.8": "torch_lts.html",
+ "whl/nightly": "torch_nightly.html",
+ "whl/test": "torch_test.html",
+ "libtorch": "index.html",
+ "libtorch/nightly": "index.html",
+}
+
+# NOTE: This refers to the name on the wheels themselves and not the name of
+# package as specified by setuptools, for packages with "-" (hyphens) in their
+# names you need to convert them to "_" (underscores) in order for them to be
+# allowed here since the name of the wheels is compared here
+PACKAGE_ALLOW_LIST = {
+ "Pillow",
+ "certifi",
+ "charset_normalizer",
+ "cmake",
+ "colorama",
+ "fbgemm_gpu",
+ "filelock",
+ "fsspec",
+ "idna",
+ "Jinja2",
+ "lit",
+ "MarkupSafe",
+ "mpmath",
+ "nestedtensor",
+ "networkx",
+ "numpy",
+ "nvidia_cublas_cu11",
+ "nvidia_cuda_cupti_cu11",
+ "nvidia_cuda_nvrtc_cu11",
+ "nvidia_cuda_runtime_cu11",
+ "nvidia_cudnn_cu11",
+ "nvidia_cufft_cu11",
+ "nvidia_curand_cu11",
+ "nvidia_cusolver_cu11",
+ "nvidia_cusparse_cu11",
+ "nvidia_nccl_cu11",
+ "nvidia_nvtx_cu11",
+ "nvidia_cublas_cu12",
+ "nvidia_cuda_cupti_cu12",
+ "nvidia_cuda_nvrtc_cu12",
+ "nvidia_cuda_runtime_cu12",
+ "nvidia_cudnn_cu12",
+ "nvidia_cufft_cu12",
+ "nvidia_curand_cu12",
+ "nvidia_cusolver_cu12",
+ "nvidia_cusparse_cu12",
+ "nvidia_nccl_cu12",
+ "nvidia_nvtx_cu12",
+ "nvidia_nvjitlink_cu12",
+ "packaging",
+ "portalocker",
+ "pytorch_triton",
+ "pytorch_triton_rocm",
+ "requests",
+ "sympy",
+ "torch",
+ "torch_tensorrt",
+ "torcharrow",
+ "torchaudio",
+ "torchcsprng",
+ "torchdata",
+ "torchdistx",
+ "torchmetrics",
+ "torchrec",
+ "torchtext",
+ "torchvision",
+ "triton",
+ "tqdm",
+ "typing_extensions",
+ "urllib3",
+ "xformers",
+}
+
# Should match torch-2.0.0.dev20221221+cu118-cp310-cp310-linux_x86_64.whl as:
# Group 1: torch-2.0.0.dev
# Group 2: 20221221
# NOTE: fixes the original `[a-zA-z]` typo (that range also matched the
# punctuation [\]^_`) and escapes the literal dot before "dev".
PACKAGE_DATE_REGEX = r"([a-zA-Z]*-[0-9.]*\.dev)([0-9]*)"
+
+# How many packages should we keep of a specific package?
+KEEP_THRESHOLD = 60
+
+S3IndexType = TypeVar('S3IndexType', bound='S3Index')
+
+
@dataclasses.dataclass(frozen=False)
@functools.total_ordering
class S3Object:
    # Lightweight record for one object in the bucket. Identity
    # (hash/eq/ordering) is keyed on `key` alone so instances can be
    # deduplicated in sets and sorted by path; total_ordering derives
    # the remaining comparisons from __eq__/__lt__.
    key: str
    # original S3 key — presumably `key` before any sanitization;
    # TODO(review): confirm against call sites outside this chunk
    orig_key: str
    checksum: Optional[str]
    size: Optional[int]

    def __hash__(self):
        # Defined explicitly: a non-frozen dataclass would otherwise be
        # unhashable once equality is customized.
        return hash(self.key)

    def __str__(self):
        return self.key

    def __eq__(self, other):
        return self.key == other.key

    def __lt__(self, other):
        return self.key < other.key
+
+
def extract_package_build_time(full_package_name: str) -> datetime:
    """Parse the nightly ``devYYYYMMDD`` stamp out of a package filename.

    Falls back to ``datetime.now()`` when the name carries no parseable
    date, so undated packages are treated as freshly built.
    """
    matched = search(PACKAGE_DATE_REGEX, full_package_name)
    if matched is None:
        return datetime.now()
    try:
        return datetime.strptime(matched.group(2), "%Y%m%d")
    except ValueError:
        # Unparseable date component — treat like an undated package.
        return datetime.now()
+
+
def between_bad_dates(package_build_time: datetime):
    """True iff the build date falls inside the known-bad window
    2022-08-17 .. 2022-12-30, inclusive on both ends."""
    window_start = datetime(2022, 8, 17)
    window_end = datetime(2022, 12, 30)
    return window_start <= package_build_time <= window_end
+
+
def safe_parse_version(ver_str: str) -> Version:
    # Lenient wrapper around packaging's parser: malformed version
    # strings sort as the minimum ("0.0.0") instead of raising, so one
    # bad filename cannot abort index generation.
    try:
        return _parse_version(ver_str)
    except InvalidVersion:
        return Version("0.0.0")
+
+
+
+class S3Index:
+ def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
+ self.objects = objects
+ self.prefix = prefix.rstrip("/")
+ self.html_name = PREFIXES_WITH_HTML[self.prefix]
+ # should dynamically grab subdirectories like whl/test/cu101
+ # so we don't need to add them manually anymore
+ self.subdirs = {
+ path.dirname(obj.key) for obj in objects if path.dirname != prefix
+ }
+
    def nightly_packages_to_show(self: S3IndexType) -> List[S3Object]:
        """Finding packages to show based on a threshold we specify

        Basically takes our S3 packages, normalizes the version for easier
        comparisons, then iterates over normalized versions until we reach a
        threshold and then starts adding package to delete after that threshold
        has been reached

        After figuring out what versions we'd like to hide we iterate over
        our original object list again and pick out the full paths to the
        packages that are included in the list of versions to delete
        """
        # also includes versions without GPU specifier (i.e. cu102) for easier
        # sorting, sorts in reverse to put the most recent versions first
        all_sorted_packages = sorted(
            {self.normalize_package_version(obj) for obj in self.objects},
            key=lambda name_ver: safe_parse_version(name_ver.split('-', 1)[-1]),
            reverse=True,
        )
        # count of versions kept so far, per package name
        packages: Dict[str, int] = defaultdict(int)
        to_hide: Set[str] = set()
        for obj in all_sorted_packages:
            full_package_name = path.basename(obj)
            package_name = full_package_name.split('-')[0]
            package_build_time = extract_package_build_time(full_package_name)
            # Hide anything NOT in the allow list (the original comment
            # here stated the opposite of what this check does)
            if package_name not in PACKAGE_ALLOW_LIST:
                to_hide.add(obj)
                continue
            # Hide once KEEP_THRESHOLD versions of this package are kept,
            # and hide builds from the known-bad window regardless of count
            if packages[package_name] >= KEEP_THRESHOLD or between_bad_dates(package_build_time):
                to_hide.add(obj)
            else:
                packages[package_name] += 1
        # Map the hidden normalized versions back to concrete S3 objects
        # and return everything that survived
        return list(set(self.objects).difference({
            obj for obj in self.objects
            if self.normalize_package_version(obj) in to_hide
        }))
+
+ def is_obj_at_root(self, obj: S3Object) -> bool:
+ return path.dirname(obj.key) == self.prefix
+
+ def _resolve_subdir(self, subdir: Optional[str] = None) -> str:
+ if not subdir:
+ subdir = self.prefix
+ # make sure we strip any trailing slashes
+ return subdir.rstrip("/")
+
+ def gen_file_list(
+ self,
+ subdir: Optional[str] = None,
+ package_name: Optional[str] = None
+ ) -> Iterable[S3Object]:
+ objects = self.objects
+ subdir = self._resolve_subdir(subdir) + '/'
+ for obj in objects:
+ if package_name is not None and self.obj_to_package_name(obj) != package_name:
+ continue
+ if self.is_obj_at_root(obj) or obj.key.startswith(subdir):
+ yield obj
+
+ def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
+ return sorted({self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)})
+
+ def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
+ # removes the GPU specifier from the package name as well as
+ # unnecessary things like the file extension, architecture name, etc.
+ return sub(
+ r"%2B.*",
+ "",
+ "-".join(path.basename(obj.key).split("-")[:2])
+ )
+
+ def obj_to_package_name(self, obj: S3Object) -> str:
+ return path.basename(obj.key).split('-', 1)[0]
+
+ def to_legacy_html(
+ self,
+ subdir: Optional[str] = None
+ ) -> str:
+ """Generates a string that can be used as the HTML index
+
+ Takes our objects and transforms them into HTML that have historically
+ been used by pip for installing pytorch.
+
+ NOTE: These are not PEP 503 compliant but are here for legacy purposes
+ """
+ out: List[str] = []
+ subdir = self._resolve_subdir(subdir)
+ is_root = subdir == self.prefix
+ for obj in self.gen_file_list(subdir):
+ # Strip our prefix
+ sanitized_obj = obj.key.replace(subdir, "", 1)
+ if sanitized_obj.startswith('/'):
+ sanitized_obj = sanitized_obj.lstrip("/")
+ # we include objects at our root prefix so that users can still
+ # install packages like torchaudio / torchtext even if they want
+ # to install a specific GPU arch of torch / torchvision
+ if not is_root and self.is_obj_at_root(obj):
+ # strip root prefix
+ sanitized_obj = obj.key.replace(self.prefix, "", 1).lstrip("/")
+ sanitized_obj = f"../{sanitized_obj}"
+ out.append(f'{sanitized_obj}
')
+ return "\n".join(sorted(out))
+
+ def to_simple_package_html(
+ self,
+ subdir: Optional[str],
+ package_name: str
+ ) -> str:
+ """Generates a string that can be used as the package simple HTML index
+ """
+ out: List[str] = []
+ # Adding html header
+ out.append('')
+ out.append('')
+ out.append('