diff --git a/.bazelrc b/.bazelrc index 7581b5243021..fc2995dc838c 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,7 +2,7 @@ build --cxxopt=--std=c++17 build --copt=-I. # Bazel does not support including its cc_library targets as system # headers. We work around this for generated code -# (e.g. c10/macros/cmake_macros.h) by making the generated directory a +# (e.g. torch/headeronly/macros/cmake_macros.h) by making the generated directory a # system include path. build --copt=-isystem --copt bazel-out/k8-fastbuild/bin build --copt=-isystem --copt bazel-out/darwin-fastbuild/bin diff --git a/.ci/caffe2/test.sh b/.ci/caffe2/test.sh index eaef1e3ebf88..7d1ce2fb4fa1 100755 --- a/.ci/caffe2/test.sh +++ b/.ci/caffe2/test.sh @@ -5,7 +5,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" if [[ ${BUILD_ENVIRONMENT} == *onnx* ]]; then pip install click mock tabulate networkx==2.0 - pip -q install --user "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx" + pip -q install "file:///var/lib/jenkins/workspace/third_party/onnx#egg=onnx" fi # Skip tests in environments where they are not built/applicable @@ -147,8 +147,8 @@ export DNNL_MAX_CPU_ISA=AVX2 if [[ "${SHARD_NUMBER:-1}" == "1" ]]; then # TODO(sdym@meta.com) remove this when the linked issue resolved. # py is temporary until https://github.com/Teemu/pytest-sugar/issues/241 is fixed - pip install --user py==1.11.0 - pip install --user pytest-sugar + pip install py==1.11.0 + pip install pytest-sugar # NB: Warnings are disabled because they make it harder to see what # the actual erroring test is "$PYTHON" \ diff --git a/.ci/docker/README.md b/.ci/docker/README.md index 15779155933e..0fd4ed7ca502 100644 --- a/.ci/docker/README.md +++ b/.ci/docker/README.md @@ -36,3 +36,105 @@ See `build.sh` for valid build environments (it's the giant switch). # Set flags (see build.sh) and build image sudo bash -c 'TRITON=1 ./build.sh pytorch-linux-bionic-py3.8-gcc9 -t myimage:latest ``` + +## [Guidance] Adding a New Base Docker Image + +### Background + +The base Docker images in directory `.ci/docker/` are built by the `docker-builds.yml` workflow. Those images are used throughout the PyTorch CI/CD pipeline. You should only create or modify a base Docker image if you need specific environment changes or dependencies before building PyTorch on CI. + +1. **Automatic Rebuilding**: + - The Docker image building process is triggered automatically when changes are made to files in the `.ci/docker/*` directory + - This ensures all images stay up-to-date with the latest dependencies and configurations + +2. **Image Reuse in PyTorch Build Workflows** (example: linux-build): + - The images generated by `docker-builds.yml` are reused in `_linux-build.yml` through the `calculate-docker-image` step + - The `_linux-build.yml` workflow: + - Pulls the Docker image determined by the `calculate-docker-image` step + - Runs a Docker container with that image + - Executes `.ci/pytorch/build.sh` inside the container to build PyTorch + +3. **Usage in Test Workflows** (example: linux-test): + - The same Docker images are also used in `_linux-test.yml` for running tests + - The `_linux-test.yml` workflow follows a similar pattern: + - It uses the `calculate-docker-image` step to determine which Docker image to use + - It pulls the Docker image and runs a container with that image + - It installs the wheels from the artifacts generated by PyTorch build jobs + - It executes test scripts (like `.ci/pytorch/test.sh` or `.ci/pytorch/multigpu-test.sh`) inside the container + +### Understanding File Purposes + +#### `.ci/docker/build.sh` vs `.ci/pytorch/build.sh` +- **`.ci/docker/build.sh`**: + - Used for building base Docker images + - Executed by the `docker-builds.yml` workflow to pre-build Docker images for CI + - Contains configurations for different Docker build environments + +- **`.ci/pytorch/build.sh`**: + - Used for building PyTorch inside a Docker container + - Called by workflows like `_linux-build.yml` after the Docker container is started + - Builds PyTorch wheels and other artifacts + +#### `.ci/docker/ci_commit_pins/` vs `.github/ci_commit_pins` +- **`.ci/docker/ci_commit_pins/`**: + - Used for pinning dependency versions during base Docker image building + - Ensures consistent environments for building PyTorch + - Changes here trigger base Docker image rebuilds + +- **`.github/ci_commit_pins`**: + - Used for pinning dependency versions during PyTorch building and tests + - Ensures consistent dependencies for PyTorch across different builds + - Used by build scripts running inside Docker containers + +### Step-by-Step Guide for Adding a New Base Docker Image + +#### 1. Add Pinned Commits (If Applicable) + +We use pinned commits for build stability. The `nightly.yml` workflow checks and updates pinned commits for certain repository dependencies daily. + +If your new Docker image needs a library installed from a specific pinned commit or built from source: + +1. Add the repository you want to track in `nightly.yml` and `merge-rules.yml` +2. Add the initial pinned commit in `.ci/docker/ci_commit_pins/`. The text filename should match the one defined in step 1 + +#### 2. Configure the Base Docker Image +1. **Add new Base Docker image configuration** (if applicable): + + Add the configuration in `.ci/docker/build.sh`. For example: + ```bash + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-new1) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + NEW_ARG_1=yes + ;; + ``` + +2. **Add build arguments to Docker build command**: + + If you're introducing a new argument to the Docker build, make sure to add it in the Docker build step in `.ci/docker/build.sh`: + ```bash + docker build \ + .... + --build-arg "NEW_ARG_1=${NEW_ARG_1}" + ``` + +3. **Update Dockerfile logic**: + + Update the Dockerfile to use the new argument. For example, in `ubuntu/Dockerfile`: + ```dockerfile + ARG NEW_ARG_1 + # Set up environment for NEW_ARG_1 + RUN if [ -n "${NEW_ARG_1}" ]; then bash ./do_something.sh; fi + ``` + +4. **Add the Docker configuration** in `.github/workflows/docker-builds.yml`: + + The `docker-builds.yml` workflow pre-builds the Docker images whenever changes occur in the `.ci/docker/` directory. This includes the + pinned commit updates. diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 97e6bce3e59d..cf022d099326 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -91,6 +91,17 @@ tag=$(echo $image | awk -F':' '{print $2}') # configuration, so we hardcode everything here rather than do it # from scratch case "$tag" in + pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11) + CUDA_VERSION=12.4 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + ;; pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11) CUDA_VERSION=12.8.1 CUDNN_VERSION=9 @@ -149,6 +160,17 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm) + CUDA_VERSION=12.8.1 + CUDNN_VERSION=9 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + VISION=yes + KATEX=yes + UCX_COMMIT=${_UCX_COMMIT} + UCC_COMMIT=${_UCC_COMMIT} + TRITON=yes + ;; pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks) CUDA_VERSION=12.6 CUDNN_VERSION=9 @@ -220,11 +242,15 @@ case "$tag" in VISION=yes TRITON=yes ;; - pytorch-linux-jammy-rocm-n-1-py3) - ANACONDA_PYTHON_VERSION=3.10 + pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) + if [[ $tag =~ "jammy" ]]; then + ANACONDA_PYTHON_VERSION=3.10 + else + ANACONDA_PYTHON_VERSION=3.12 + fi GCC_VERSION=11 VISION=yes - ROCM_VERSION=6.3 + ROCM_VERSION=6.4 NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes @@ -232,21 +258,18 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3) - if [[ $tag =~ "jammy" ]]; then - ANACONDA_PYTHON_VERSION=3.10 - else - ANACONDA_PYTHON_VERSION=3.12 - fi + pytorch-linux-noble-rocm-alpha-py3) + ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 VISION=yes - ROCM_VERSION=6.4 + ROCM_VERSION=7.0 NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} INDUCTOR_BENCHMARKS=yes + PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950" ;; pytorch-linux-jammy-xpu-2025.0-py3) ANACONDA_PYTHON_VERSION=3.9 @@ -264,7 +287,7 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) + pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.9 GCC_VERSION=11 VISION=yes diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 568756a804f0..6dc1c44507eb 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -ae848267bebc65c6181e8cc5e64a6357d2679260 +11ec6354315768a85da41032535e3b7b99c5f706 diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 185837b7e98a..481de54a50f2 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -4,12 +4,8 @@ set -ex # Optionally install conda if [ -n "$ANACONDA_PYTHON_VERSION" ]; then - BASE_URL="https://repo.anaconda.com/miniconda" - CONDA_FILE="Miniconda3-latest-Linux-x86_64.sh" - if [[ $(uname -m) == "aarch64" ]] || [[ "$BUILD_ENVIRONMENT" == *xpu* ]] || [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then - BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore - CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" - fi + BASE_URL="https://github.com/conda-forge/miniforge/releases/latest/download" # @lint-ignore + CONDA_FILE="Miniforge3-Linux-$(uname -m).sh" MAJOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 1) MINOR_PYTHON_VERSION=$(echo "$ANACONDA_PYTHON_VERSION" | cut -d . -f 2) @@ -21,7 +17,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then exit 1 ;; esac - mkdir -p /opt/conda chown jenkins:jenkins /opt/conda diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index cd9701e7590b..c8a780f65c8e 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -78,6 +78,19 @@ function install_nvshmem { echo "nvSHMEM ${nvshmem_version} for CUDA ${cuda_major_version} (${arch_path}) installed." } +function install_124 { + CUDNN_VERSION=9.1.0.70 + echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" + install_cuda 12.4.1 cuda_12.4.1_550.54.15_linux + + install_cudnn 12 $CUDNN_VERSION + + CUDA_VERSION=12.4 bash install_nccl.sh + + CUDA_VERSION=12.4 bash install_cusparselt.sh + + ldconfig +} function install_126 { CUDNN_VERSION=9.10.2.21 @@ -113,6 +126,40 @@ function install_129 { ldconfig } +function prune_124 { + echo "Pruning CUDA 12.4" + ##################################################################################### + # CUDA 12.4 prune static libs + ##################################################################################### + export NVPRUNE="/usr/local/cuda-12.4/bin/nvprune" + export CUDA_LIB_DIR="/usr/local/cuda-12.4/lib64" + + export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + + if [[ -n "$OVERRIDE_GENCODE" ]]; then + export GENCODE=$OVERRIDE_GENCODE + fi + if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then + export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN + fi + + # all CUDA libs except CuDNN and CuBLAS + ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ + | xargs -I {} bash -c \ + "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" + + # prune CuDNN and CuBLAS + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a + + ##################################################################################### + # CUDA 12.4 prune visual tools + ##################################################################################### + export CUDA_BASE="/usr/local/cuda-12.4/" + rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.1.0 $CUDA_BASE/nsight-systems-2023.4.4/ +} + function prune_126 { echo "Pruning CUDA 12.6" ##################################################################################### @@ -169,6 +216,8 @@ function install_128 { while test $# -gt 0 do case "$1" in + 12.4) install_124; prune_124 + ;; 12.6|12.6.*) install_126; prune_126 ;; 12.8|12.8.*) install_128; diff --git a/.ci/docker/common/install_cudnn.sh b/.ci/docker/common/install_cudnn.sh index 7ee5e73226cb..fecdb448589e 100644 --- a/.ci/docker/common/install_cudnn.sh +++ b/.ci/docker/common/install_cudnn.sh @@ -8,6 +8,8 @@ if [[ -n "${CUDNN_VERSION}" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" + elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then + CUDNN_NAME="cudnn-linux-x86_64-9.10.2.21_cuda12-archive" elif [[ ${CUDA_VERSION:0:2} == "11" ]]; then CUDNN_NAME="cudnn-linux-x86_64-9.1.0.70_cuda11-archive" else diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index ca29a94e58fc..feacb49f39eb 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -13,6 +13,14 @@ if [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then fi CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive" curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz +elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then + arch_path='sbsa' + export TARGETARCH=${TARGETARCH:-$(uname -m)} + if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then + arch_path='x86_64' + fi + CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive" + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz else echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}" fi diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 39a3f0eaf1c4..02406ab71cde 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -33,13 +33,22 @@ EOF ROCM_VERSION="${ROCM_VERSION}.1" fi + # Default url values + rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" + amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu" + + # Special case for ROCM_VERSION == 7.0 + if [[ $(ver "$ROCM_VERSION") -eq $(ver 7.0) ]]; then + rocm_baseurl="https://repo.radeon.com/rocm/apt/7.0_alpha2" + amdgpu_baseurl="https://repo.radeon.com/amdgpu/30.10_alpha2/ubuntu" + fi + # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` - echo "deb [arch=amd64] https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list + echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list # Add rocm repository wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - - local rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" echo "deb [arch=amd64] ${rocm_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/rocm.list apt-get update --allow-insecure-repositories @@ -73,30 +82,30 @@ EOF done # ROCm 6.3 had a regression where initializing static code objects had significant overhead + # CI no longer builds for ROCm 6.3, but # ROCm 6.4 did not yet fix the regression, also HIP branch names are different - if [[ $(ver $ROCM_VERSION) -ge $(ver 6.3) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then + if [[ $(ver $ROCM_VERSION) -ge $(ver 6.4) ]] && [[ $(ver $ROCM_VERSION) -lt $(ver 7.0) ]]; then if [[ $(ver $ROCM_VERSION) -eq $(ver 6.4.1) ]]; then HIP_BRANCH=release/rocm-rel-6.4 - VER_STR=6.4 - VER_PATCH=.1 + CLR_HASH=606bc820b4b1f315d135da02a1f0b176ca50a92c # branch release/rocm-rel-6.4.1-statco-hotfix elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.4) ]]; then HIP_BRANCH=release/rocm-rel-6.4 - VER_STR=6.4 - elif [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then - HIP_BRANCH=rocm-6.3.x - VER_STR=6.3 + CLR_HASH=600f5b0d2baed94d5121e2174a9de0851b040b0c # branch release/rocm-rel-6.4-statco-hotfix fi # clr build needs CppHeaderParser but can only find it using conda's python python -m pip install CppHeaderParser git clone https://github.com/ROCm/HIP -b $HIP_BRANCH HIP_COMMON_DIR=$(readlink -f HIP) - git clone https://github.com/jeffdaily/clr -b release/rocm-rel-${VER_STR}${VER_PATCH}-statco-hotfix + git clone https://github.com/jeffdaily/clr + pushd clr + git checkout $CLR_HASH + popd mkdir -p clr/build pushd clr/build # Need to point CMake to the correct python installation to find CppHeaderParser cmake .. -DPython3_EXECUTABLE=/opt/conda/envs/py_${ANACONDA_PYTHON_VERSION}/bin/python3 -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR make -j - cp hipamd/lib/libamdhip64.so.${VER_STR}.* /opt/rocm/lib/libamdhip64.so.${VER_STR}.* + cp hipamd/lib/libamdhip64.so.6.4.* /opt/rocm/lib/libamdhip64.so.6.4.* popd rm -rf HIP clr fi diff --git a/.ci/docker/linter/Dockerfile b/.ci/docker/linter/Dockerfile index 0fdfac678d40..95d08ffea051 100644 --- a/.ci/docker/linter/Dockerfile +++ b/.ci/docker/linter/Dockerfile @@ -27,5 +27,7 @@ COPY ./common/install_linter.sh install_linter.sh RUN bash ./install_linter.sh RUN rm install_linter.sh +RUN chown -R jenkins:jenkins /var/lib/jenkins/ci_env + USER jenkins CMD ["bash"] diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 9c8251989477..fb773ff324af 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -389,3 +389,9 @@ tlparse==0.3.30 cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" #Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits. #test that import: test_cuda.py + +setuptools-git-versioning==2.1.0 +scikit-build==0.18.1 +pyre-extensions==0.0.32 +tabulate==0.9.0 +#Description: These package are needed to build FBGEMM and torchrec on PyTorch CI diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 54e9dbdfca26..8ff9f07c84a8 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -4,7 +4,7 @@ sphinx==5.3.0 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering -# but it doesn't seem to work and hangs around idly. The initial thought is probably +# but it doesn't seem to work and hangs around idly. The initial thought that it is probably # something related to Docker setup. We can investigate this later. sphinxcontrib.katex==0.8.6 @@ -59,3 +59,4 @@ sphinx-copybutton==0.5.0 sphinx-design==0.4.0 sphinxcontrib-mermaid==1.0.0 myst-parser==0.18.1 +myst-nb diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index 6437cf9f0d48..49549c9f2994 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -97,8 +97,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" -retry pip install -q "setuptools>=70.1.0" packaging -retry pip install -qU cmake ninja +retry pip install -qUr requirements-build.txt python setup.py clean retry pip install -qr requirements.txt case ${DESIRED_PYTHON} in diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index 30a723cb1095..4de775b1823c 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -92,8 +92,7 @@ if [[ -z "$PYTORCH_ROOT" ]]; then exit 1 fi pushd "$PYTORCH_ROOT" -retry pip install -q "setuptools>=70.1.0" packaging -retry pip install -qU cmake ninja +retry pip install -qUr requirements-build.txt python setup.py clean retry pip install -qr requirements.txt retry pip install -q numpy==2.0.1 diff --git a/.ci/onnx/test.sh b/.ci/onnx/test.sh index a7d3b72c62a7..d42ca2c218de 100755 --- a/.ci/onnx/test.sh +++ b/.ci/onnx/test.sh @@ -19,7 +19,7 @@ git config --global --add safe.directory /var/lib/jenkins/workspace if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # TODO: This can be removed later once vision is also part of the Docker image - pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" + pip install -q --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" # JIT C++ extensions require ninja, so put it into PATH. export PATH="/var/lib/jenkins/.local/bin:$PATH" # NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 994bd179e464..58454bcb108a 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -306,6 +306,22 @@ else fi pip_install_whl "$(echo dist/*.whl)" + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then + install_torchvision + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *audio* ]]; then + install_torchaudio + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES:-}" == *fbgemm* ]]; then + install_torchrec_and_fbgemm + fi + + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchao* ]]; then + install_torchao + fi + if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then echo "Checking that xpu is compiled" pushd dist/ diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 9c0e5242f433..e9c7741947cf 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -78,6 +78,34 @@ function pip_install_whl() { fi } +function pip_build_and_install() { + local build_target=$1 + local wheel_dir=$2 + + local found_whl=0 + for file in "${wheel_dir}"/*.whl + do + if [[ -f "${file}" ]]; then + found_whl=1 + break + fi + done + + # Build the wheel if it doesn't exist + if [ "${found_whl}" == "0" ]; then + python3 -m pip wheel \ + --no-build-isolation \ + --no-deps \ + --no-use-pep517 \ + -w "${wheel_dir}" \ + "${build_target}" + fi + + for file in "${wheel_dir}"/*.whl + do + pip_install_whl "${file}" + done +} function pip_install() { # retry 3 times @@ -124,14 +152,7 @@ function get_pinned_commit() { function install_torchaudio() { local commit commit=$(get_pinned_commit audio) - if [[ "$1" == "cuda" ]]; then - # TODO: This is better to be passed as a parameter from _linux-test workflow - # so that it can be consistent with what is set in build - TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}" - else - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/audio.git@${commit}" - fi - + pip_build_and_install "git+https://github.com/pytorch/audio.git@${commit}" dist/audio } function install_torchtext() { @@ -139,8 +160,8 @@ function install_torchtext() { local text_commit data_commit=$(get_pinned_commit data) text_commit=$(get_pinned_commit text) - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/data.git@${data_commit}" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/text.git@${text_commit}" + pip_build_and_install "git+https://github.com/pytorch/data.git@${data_commit}" dist/data + pip_build_and_install "git+https://github.com/pytorch/text.git@${text_commit}" dist/text } function install_torchvision() { @@ -153,7 +174,14 @@ function install_torchvision() { echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so fi - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" + + if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then + # Not sure if both are needed, but why not + export FORCE_CUDA=1 + export WITH_CUDA=1 + fi + pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision + if [ -n "${LD_PRELOAD}" ]; then LD_PRELOAD=${orig_preload} fi @@ -173,25 +201,73 @@ function install_torchrec_and_fbgemm() { if [[ "$BUILD_ENVIRONMENT" == *rocm* ]] ; then # install torchrec first because it installs fbgemm nightly on top of rocm fbgemm - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec pip_uninstall fbgemm-gpu-nightly + # Set ROCM_HOME isn't available, use ROCM_PATH if set or /opt/rocm + ROCM_HOME="${ROCM_HOME:-${ROCM_PATH:-/opt/rocm}}" + + # Find rocm_version.h header file for ROCm version extract + rocm_version_h="${ROCM_HOME}/include/rocm-core/rocm_version.h" + if [ ! -f "$rocm_version_h" ]; then + rocm_version_h="${ROCM_HOME}/include/rocm_version.h" + fi + + # Error out if rocm_version.h not found + if [ ! -f "$rocm_version_h" ]; then + echo "Error: rocm_version.h not found in expected locations." >&2 + exit 1 + fi + + # Extract major, minor and patch ROCm version numbers + MAJOR_VERSION=$(grep 'ROCM_VERSION_MAJOR' "$rocm_version_h" | awk '{print $3}') + MINOR_VERSION=$(grep 'ROCM_VERSION_MINOR' "$rocm_version_h" | awk '{print $3}') + PATCH_VERSION=$(grep 'ROCM_VERSION_PATCH' "$rocm_version_h" | awk '{print $3}') + ROCM_INT=$((MAJOR_VERSION * 10000 + MINOR_VERSION * 100 + PATCH_VERSION)) + echo "ROCm version: $ROCM_INT" + export BUILD_ROCM_VERSION="$MAJOR_VERSION.$MINOR_VERSION" + pip_install tabulate # needed for newer fbgemm pip_install patchelf # needed for rocm fbgemm - git clone --recursive https://github.com/pytorch/fbgemm - pushd fbgemm/fbgemm_gpu - git checkout "${fbgemm_commit}" - python setup.py install \ - --package_variant=rocm \ - -DHIP_ROOT_DIR="${ROCM_PATH}" \ - -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ - -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" - popd + pushd /tmp + + local wheel_dir=dist/fbgemm_gpu + local found_whl=0 + for file in "${wheel_dir}"/*.whl + do + if [[ -f "${file}" ]]; then + found_whl=1 + break + fi + done + + # Build the wheel if it doesn't exist + if [ "${found_whl}" == "0" ]; then + git clone --recursive https://github.com/pytorch/fbgemm + pushd fbgemm/fbgemm_gpu + git checkout "${fbgemm_commit}" + python setup.py bdist_wheel \ + --build-variant=rocm \ + -DHIP_ROOT_DIR="${ROCM_PATH}" \ + -DCMAKE_C_FLAGS="-DTORCH_USE_HIP_DSA" \ + -DCMAKE_CXX_FLAGS="-DTORCH_USE_HIP_DSA" + popd + + # Save the wheel before cleaning up + mkdir -p dist/fbgemm_gpu + cp fbgemm/fbgemm_gpu/dist/*.whl dist/fbgemm_gpu + fi + + for file in "${wheel_dir}"/*.whl + do + pip_install_whl "${file}" + done + rm -rf fbgemm + popd else - # See https://github.com/pytorch/pytorch/issues/106971 - CUDA_PATH=/usr/local/cuda-12.1 pip_install --no-use-pep517 --user "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#egg=fbgemm-gpu&subdirectory=fbgemm_gpu" - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" + pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec + pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu fi } @@ -234,7 +310,7 @@ function checkout_install_torchbench() { function install_torchao() { local commit commit=$(get_pinned_commit torchao) - pip_install --no-use-pep517 --user "git+https://github.com/pytorch/ao.git@${commit}" + pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao } function print_sccache_stats() { diff --git a/.ci/pytorch/run_tests.sh b/.ci/pytorch/run_tests.sh index 34ee40d7bcd0..f5ed90deef24 100755 --- a/.ci/pytorch/run_tests.sh +++ b/.ci/pytorch/run_tests.sh @@ -74,12 +74,13 @@ else fi # Environment initialization +retry pip install -qUr requirements-build.txt if [[ "$(uname)" == Darwin ]]; then # Install the testing dependencies - retry pip install -q future hypothesis ${NUMPY_PACKAGE} ${PROTOBUF_PACKAGE} pytest setuptools six typing_extensions pyyaml + retry pip install -q future hypothesis ${NUMPY_PACKAGE} ${PROTOBUF_PACKAGE} pytest else retry pip install -qr requirements.txt || true - retry pip install -q hypothesis protobuf pytest setuptools || true + retry pip install -q hypothesis protobuf pytest || true numpy_ver=1.15 case "$(python --version 2>&1)" in *2* | *3.5* | *3.6*) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 7b7e1970f72e..b16557061d11 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -201,7 +201,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then # JIT C++ extensions require ninja. - pip_install --user "ninja==1.10.2" + pip_install "ninja==1.10.2" # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins # but this script should be runnable by any user, including root export PATH="$HOME/.local/bin:$PATH" @@ -289,6 +289,12 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +if [[ "${TEST_CONFIG}" == "legacy_nvidia_driver" ]]; then + # Make sure that CUDA can be initialized + (cd test && python -c "import torch; torch.rand(2, 2, device='cuda')") + export USE_LEGACY_DRIVER=1 +fi + test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty @@ -339,6 +345,12 @@ test_h100_symm_mem() { assert_git_not_dirty } +test_h100_cutlass_backend() { + # cutlass backend tests for H100 + TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_backend -k "not addmm" $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_evt $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running +} + test_lazy_tensor_meta_reference_disabled() { export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1 echo "Testing lazy tensor operations without meta reference" @@ -496,7 +508,7 @@ DYNAMO_BENCHMARK_FLAGS=() pr_time_benchmarks() { - pip_install --user "fbscribelogger" + pip_install "fbscribelogger" TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" @@ -1471,8 +1483,8 @@ test_bazel() { test_benchmarks() { if [[ "$BUILD_ENVIRONMENT" == *cuda* && $TEST_CONFIG != *nogpu* ]]; then - pip_install --user "pytest-benchmark==3.2.3" - pip_install --user "requests" + pip_install "pytest-benchmark==3.2.3" + pip_install "requests" BENCHMARK_DATA="benchmarks/.data" mkdir -p ${BENCHMARK_DATA} pytest benchmarks/fastrnns/test_bench.py --benchmark-sort=Name --benchmark-json=${BENCHMARK_DATA}/fastrnns_default.json --fuser=default --executor=default @@ -1600,7 +1612,13 @@ if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-baze fi if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then # Install numpy-2.0.2 and compatible scipy & numba versions - python -mpip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 + # Force re-install of pandas to avoid error where pandas checks numpy version from initial install and fails upon import + TMP_PANDAS_VERSION=$(python -c "import pandas; print(pandas.__version__)" 2>/dev/null) + if [ -n "$TMP_PANDAS_VERSION" ]; then + python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 pandas=="$TMP_PANDAS_VERSION" --force-reinstall + else + python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 + fi python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" != *perf_cpu_aarch64* ]]; then test_linux_aarch64 @@ -1654,23 +1672,19 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then id=$((SHARD_NUMBER-1)) test_dynamo_benchmark timm_models "$id" elif [[ "${TEST_CONFIG}" == cachebench ]]; then - install_torchaudio cuda + install_torchaudio install_torchvision checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco PYTHONPATH=$(pwd)/torchbench test_cachebench elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then - install_torchaudio cpu + install_torchaudio install_torchvision checkout_install_torchbench nanogpt PYTHONPATH=$(pwd)/torchbench test_verify_cachebench elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then - if [[ "${TEST_CONFIG}" == *cpu* ]]; then - install_torchaudio cpu - else - install_torchaudio cuda - fi + install_torchaudio install_torchvision - TORCH_CUDA_ARCH_LIST="8.0;8.6" install_torchao + install_torchao id=$((SHARD_NUMBER-1)) # https://github.com/opencv/opencv-python/issues/885 pip_install opencv-python==4.8.0.74 @@ -1761,6 +1775,8 @@ elif [[ "${TEST_CONFIG}" == h100_distributed ]]; then test_h100_distributed elif [[ "${TEST_CONFIG}" == "h100-symm-mem" ]]; then test_h100_symm_mem +elif [[ "${TEST_CONFIG}" == h100_cutlass_backend ]]; then + test_h100_cutlass_backend else install_torchvision install_monkeytype diff --git a/.ci/pytorch/win-arm64-build.ps1 b/.ci/pytorch/win-arm64-build.ps1 new file mode 100644 index 000000000000..2cb162b8a301 --- /dev/null +++ b/.ci/pytorch/win-arm64-build.ps1 @@ -0,0 +1,34 @@ +# If you want to rebuild, run this with $env:REBUILD=1 +# If you want to build with CUDA, run this with $env:USE_CUDA=1 +# If you want to build without CUDA, run this with $env:USE_CUDA=0 + +# Check for setup.py in the current directory +if (-not (Test-Path "setup.py")) { + Write-Host "ERROR: Please run this build script from PyTorch root directory." + exit 1 +} + +# Get the script's parent directory +$ScriptParentDir = Split-Path -Parent $MyInvocation.MyCommand.Definition + +# Set TMP_DIR and convert to Windows path +$env:TMP_DIR = Join-Path (Get-Location) "build\win_tmp" +$env:TMP_DIR_WIN = $env:TMP_DIR # Already in Windows format, no cygpath needed + +# Set final package directory with default fallback +if (-not $env:PYTORCH_FINAL_PACKAGE_DIR) { + $env:PYTORCH_FINAL_PACKAGE_DIR = "C:\w\build-results" +} + +# Create the final package directory if it doesn't exist +if (-not (Test-Path $env:PYTORCH_FINAL_PACKAGE_DIR)) { + New-Item -Path $env:PYTORCH_FINAL_PACKAGE_DIR -ItemType Directory -Force | Out-Null +} + +# Set script helpers directory +$env:SCRIPT_HELPERS_DIR = Join-Path $ScriptParentDir "win-test-helpers\arm64" + +# Run the main build script +& "$env:SCRIPT_HELPERS_DIR\build_pytorch.ps1" + +Write-Host "BUILD PASSED" diff --git a/.ci/pytorch/win-arm64-test.sh b/.ci/pytorch/win-arm64-test.sh new file mode 100644 index 000000000000..662c561aa896 --- /dev/null +++ b/.ci/pytorch/win-arm64-test.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set -ex -o pipefail + +SCRIPT_PARENT_DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) +# shellcheck source=./common.sh +source "$SCRIPT_PARENT_DIR/common.sh" + +run_tests() { + echo Running smoke_test.py... + python ./.ci/pytorch/smoke_test/smoke_test.py --package torchonly + + echo Running test_autograd.oy, test_nn.py, test_torch.py... + cd test + + CORE_TEST_LIST=("test_autograd.py" "test_nn.py" "test_modules.py") + + for t in "${CORE_TEST_LIST[@]}"; do + echo "Running test: $t" + python "$t" --verbose --save-xml --use-pytest -vvvv -rfEsxXP -p no:xdist + done +} + +run_tests +echo "TEST PASSED" diff --git a/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 new file mode 100644 index 000000000000..29b3e913439c --- /dev/null +++ b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 @@ -0,0 +1,98 @@ +# TODO: we may can use existing build_pytorch.bat for arm64 + +if ($env:DEBUG -eq "1") { + $env:BUILD_TYPE = "debug" +} else { + $env:BUILD_TYPE = "release" +} + +# This inflates our log size slightly, but it is REALLY useful to be +# able to see what our cl.exe commands are. (since you can actually +# just copy-paste them into a local Windows setup to just rebuild a +# single file.) +# log sizes are too long, but leaving this here in case someone wants to use it locally +# $env:CMAKE_VERBOSE_MAKEFILE = "1" + +$env:INSTALLER_DIR = Join-Path $env:SCRIPT_HELPERS_DIR "installation-helpers" + +cd .. + +# Environment variables +$env:SCCACHE_IDLE_TIMEOUT = "0" +$env:SCCACHE_IGNORE_SERVER_IO_ERROR = "1" +$env:CMAKE_BUILD_TYPE = $env:BUILD_TYPE +$env:CMAKE_C_COMPILER_LAUNCHER = "sccache" +$env:CMAKE_CXX_COMPILER_LAUNCHER = "sccache" +$env:libuv_ROOT = Join-Path $env:DEPENDENCIES_DIR "libuv\install" +$env:MSSdk = "1" + +if ($env:PYTORCH_BUILD_VERSION) { + $env:PYTORCH_BUILD_VERSION = $env:PYTORCH_BUILD_VERSION + $env:PYTORCH_BUILD_NUMBER = "1" +} + +$env:CMAKE_POLICY_VERSION_MINIMUM = "3.5" + +# Set BLAS type +if ($env:ENABLE_APL -eq "1") { + $env:BLAS = "APL" + $env:USE_LAPACK = "1" +} elseif ($env:ENABLE_OPENBLAS -eq "1") { + $env:BLAS = "OpenBLAS" + $env:OpenBLAS_HOME = Join-Path $env:DEPENDENCIES_DIR "OpenBLAS\install" +} + +# Change to source directory +Set-Location $env:PYTORCH_ROOT + +# Copy libuv.dll +Copy-Item -Path (Join-Path $env:libuv_ROOT "lib\Release\uv.dll") -Destination "torch\lib\uv.dll" -Force + +# Create virtual environment +python -m venv .venv +.\.venv\Scripts\Activate.ps1 +where.exe python + +# Python install dependencies +python -m pip install --upgrade pip +pip install setuptools pyyaml +pip install -r requirements.txt + +# Set after installing psutil +$env:DISTUTILS_USE_SDK = "1" + +# Print all environment variables +Get-ChildItem Env: + +# Start and inspect sccache +sccache --start-server +sccache --zero-stats +sccache --show-stats + +# Build the wheel +python setup.py bdist_wheel +if ($LASTEXITCODE -ne 0) { exit 1 } + +# Install the wheel locally +$whl = Get-ChildItem -Path "dist\*.whl" | Select-Object -First 1 +if ($whl) { + python -mpip install --no-index --no-deps $whl.FullName +} + +# Copy final wheel +robocopy "dist" "$env:PYTORCH_FINAL_PACKAGE_DIR" *.whl + +# Export test times +python tools/stats/export_test_times.py + +# Copy additional CI files +robocopy ".additional_ci_files" "$env:PYTORCH_FINAL_PACKAGE_DIR\.additional_ci_files" /E + +# Save ninja log +Copy-Item -Path "build\.ninja_log" -Destination $env:PYTORCH_FINAL_PACKAGE_DIR -Force + +# Final sccache stats and stop +sccache --show-stats +sccache --stop-server + +exit 0 diff --git a/.ci/pytorch/windows/cuda129.bat b/.ci/pytorch/windows/cuda129.bat index 77ef14921aa6..b17e6113c63e 100644 --- a/.ci/pytorch/windows/cuda129.bat +++ b/.ci/pytorch/windows/cuda129.bat @@ -37,10 +37,10 @@ IF "%CUDA_PATH_V129%"=="" ( ) IF "%BUILD_VISION%" == "" ( - set TORCH_CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0 + set TORCH_CUDA_ARCH_LIST=7.0;7.5;8.0;8.6;9.0;10.0;12.0 set TORCH_NVCC_FLAGS=-Xfatbin -compress-all ) ELSE ( - set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 + set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120 ) set "CUDA_PATH=%CUDA_PATH_V129%" diff --git a/.ci/pytorch/windows/internal/smoke_test.bat b/.ci/pytorch/windows/internal/smoke_test.bat index b7463f855428..f671a9d0e0ab 100644 --- a/.ci/pytorch/windows/internal/smoke_test.bat +++ b/.ci/pytorch/windows/internal/smoke_test.bat @@ -148,14 +148,7 @@ if "%NVIDIA_GPU_EXISTS%" == "0" ( goto end ) -set BUILD_SPLIT_CUDA= -if exist "%install_root%\lib\torch_cuda_cu.lib" if exist "%install_root%\lib\torch_cuda_cpp.lib" set BUILD_SPLIT_CUDA=ON - -if "%BUILD_SPLIT_CUDA%" == "ON" ( - cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda_cu.lib torch_cuda_cpp.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ /INCLUDE:?_torch_cuda_cu_linker_symbol_op_cuda@native@at@@YA?AVTensor@2@AEBV32@@Z -) else ( - cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ -) +cl %PYTORCH_ROOT%\.ci\pytorch\test_example_code\check-torch-cuda.cpp torch_cpu.lib c10.lib torch_cuda.lib /EHsc /std:c++17 /link /INCLUDE:?warp_size@cuda@at@@YAHXZ .\check-torch-cuda.exe if ERRORLEVEL 1 exit /b 1 diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 6070e967ef82..878d6595c84c 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -184,7 +184,8 @@ tmp_env_name="wheel_py$python_nodot" conda create ${EXTRA_CONDA_INSTALL_FLAGS} -yn "$tmp_env_name" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} source activate "$tmp_env_name" -pip install "numpy=${NUMPY_PINNED_VERSION}" "pyyaml${PYYAML_PINNED_VERSION}" requests ninja "setuptools${SETUPTOOLS_PINNED_VERSION}" typing_extensions +retry pip install -r "${pytorch_rootdir}/requirements-build.txt" +pip install "numpy=${NUMPY_PINNED_VERSION}" "pyyaml${PYYAML_PINNED_VERSION}" requests ninja "setuptools${SETUPTOOLS_PINNED_VERSION}" typing-extensions retry pip install -r "${pytorch_rootdir}/requirements.txt" || true retry brew install libomp diff --git a/.github/actions/linux-test/action.yml b/.github/actions/linux-test/action.yml index fb46709d9b0d..32fe1d7385b1 100644 --- a/.github/actions/linux-test/action.yml +++ b/.github/actions/linux-test/action.yml @@ -126,7 +126,7 @@ runs: shell: bash continue-on-error: true run: | - python3 -m pip install psutil==5.9.1 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 7a71e6f2a5e4..b49cbe79f9d7 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -6c57850358f34c47802db216b0746e4e9d08a95a +b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e diff --git a/.github/ci_commit_pins/fbgemm_rocm.txt b/.github/ci_commit_pins/fbgemm_rocm.txt index fa11e10ca6b8..db140a31f3fa 100644 --- a/.github/ci_commit_pins/fbgemm_rocm.txt +++ b/.github/ci_commit_pins/fbgemm_rocm.txt @@ -1 +1 @@ -5fb5024118e9bb9decf96c2b0b1a8f0010bf56be +7f1de94a4c2d14f59ad4ca84538c36084ea6b2c8 diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt new file mode 100644 index 000000000000..07270e9f557b --- /dev/null +++ b/.github/ci_commit_pins/vllm.txt @@ -0,0 +1 @@ +b77c7d327f2a463bb9ef8be36f30e920bc066502 diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 5786c2aa1652..17b5f49d9ed7 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -76,6 +76,7 @@ - .github/ci_commit_pins/audio.txt - .github/ci_commit_pins/vision.txt - .github/ci_commit_pins/torchdynamo.txt + - .github/ci_commit_pins/vllm.txt - .ci/docker/ci_commit_pins/triton.txt approved_by: - pytorchbot @@ -491,6 +492,19 @@ - srossross - chillee - zou3519 + - guilhermeleobas + mandatory_checks_name: + - EasyCLA + - Lint + - pull + +- name: Dynamo + patterns: + - torch/_dynamo/** + - torch/csrc/dynamo/** + - test/dynamo/** + approved_by: + - guilhermeleobas mandatory_checks_name: - EasyCLA - Lint diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index ac8cb3df0ffc..a5982b63b70f 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -31,7 +31,9 @@ ciflow_push_tags: - ciflow/pull - ciflow/h100 - ciflow/h100-distributed +- ciflow/win-arm64 - ciflow/h100-symm-mem +- ciflow/h100-cutlass-backend retryable_workflows: - pull - trunk diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 5e2819c8a836..5c691e4bf9b3 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -1,5 +1,6 @@ # This file is to cache other dependencies not specified elsewhere in: -# requirement.txt +# requirements.txt +# requirements-build.txt # docs/requirements.txt # docs/cpp/requirements.txt # functorch/docs/requirements.txt diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index e8464f0a55ff..9c72c71523b7 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -16,7 +16,7 @@ packaging==23.1 parameterized==0.8.1 pillow==10.3.0 protobuf==5.29.4 -psutil==5.9.1 +psutil==5.9.8 pygments==2.15.0 pytest-cpp==2.3.0 pytest-flakefinder==1.1.0 diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 55cb02504ea4..4df6150f9765 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -22,6 +22,7 @@ LABEL_CIFLOW_PERIODIC = "ciflow/periodic" LABEL_CIFLOW_BINARIES_LIBTORCH = "ciflow/binaries_libtorch" LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel" +LABEL_CIFLOW_ROCM = "ciflow/rocm" @dataclass @@ -146,13 +147,35 @@ class OperatingSystem: ), ] +ROCM_SMOKE_WORKFLOWS = [ + BinaryBuildWorkflow( + os=OperatingSystem.LINUX, + package_type="manywheel", + build_variant="rocm", + build_configs=generate_binary_build_matrix.generate_wheels_matrix( + OperatingSystem.LINUX, + arches=["6.4"], + python_versions=["3.9"], + ), + ciflow_config=CIFlowConfig( + labels={ + LABEL_CIFLOW_BINARIES, + LABEL_CIFLOW_BINARIES_WHEEL, + LABEL_CIFLOW_ROCM, + }, + isolated_workflow=True, + ), + branches="main", + ), +] + LINUX_BINARY_SMOKE_WORKFLOWS = [ BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="manywheel", build_configs=generate_binary_build_matrix.generate_wheels_matrix( OperatingSystem.LINUX, - arches=["12.6", "12.8", "12.9", "6.4"], + arches=["12.6", "12.8", "12.9"], python_versions=["3.9"], ), branches="main", @@ -387,6 +410,11 @@ def main() -> None: jinja_env.get_template("linux_binary_build_workflow.yml.j2"), S390X_BINARY_BUILD_WORKFLOWS, ), + ( + # Give rocm it's own workflow file + jinja_env.get_template("linux_binary_build_workflow.yml.j2"), + ROCM_SMOKE_WORKFLOWS, + ), ( jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_SMOKE_WORKFLOWS, diff --git a/.github/workflows/_get-changed-files.yml b/.github/workflows/_get-changed-files.yml new file mode 100644 index 000000000000..55712b065270 --- /dev/null +++ b/.github/workflows/_get-changed-files.yml @@ -0,0 +1,43 @@ +name: Get Changed Files + +on: + workflow_call: + outputs: + changed-files: + description: "List of changed files (space-separated) or '*' if not in a PR" + value: ${{ jobs.get-changed-files.outputs.changed-files }} + +jobs: + get-changed-files: + runs-on: ubuntu-latest + outputs: + changed-files: ${{ steps.get-files.outputs.changed-files }} + + steps: + - name: Get changed files + id: get-files + env: + GH_TOKEN: ${{ github.token }} + run: | + # Check if we're in a pull request context + if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then + echo "Running in PR context" + + # Get the PR number from the github context + PR_NUMBER="${{ github.event.number }}" + + # Use gh CLI to get changed files in the PR with explicit repo + CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//') + + if [ -z "$CHANGED_FILES" ]; then + echo "No changed files found, setting to '*'" + CHANGED_FILES="*" + fi + + echo "Changed files: $CHANGED_FILES" + echo "changed-files=$CHANGED_FILES" >> "$GITHUB_OUTPUT" + + else + echo "Not in PR context, setting changed files to '*'" + echo "changed-files=*" >> "$GITHUB_OUTPUT" + fi diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index f1e2f917f4bc..5173425009f6 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -16,11 +16,6 @@ on: type: boolean default: true description: If set, upload generated build artifacts. - build-with-debug: - required: false - type: boolean - default: false - description: If set, build in debug mode. sync-tag: required: false type: string @@ -87,7 +82,6 @@ on: required: false type: number default: 1 - allow-reuse-old-whl: description: | If set, the build try to pull an old wheel from s3 that was built on a @@ -95,6 +89,13 @@ on: required: false type: boolean default: true + build-additional-packages: + description: | + If set, the build job will also builds these packages and saves their + wheels as artifacts + required: false + type: string + default: "" secrets: HUGGING_FACE_HUB_TOKEN: @@ -106,7 +107,6 @@ on: description: | FB app token to write to scribe endpoint - outputs: docker-image: value: ${{ jobs.build.outputs.docker-image }} @@ -225,7 +225,7 @@ jobs: MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | mkdir -p ../../usage_logs - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 python3 -m tools.stats.monitor \ --log-interval "$MONITOR_LOG_INTERVAL" \ --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \ @@ -247,8 +247,6 @@ jobs: env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} BRANCH: ${{ steps.parse-ref.outputs.branch }} - # TODO duplicated - AWS_DEFAULT_REGION: us-east-1 PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs @@ -260,10 +258,10 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} DOCKER_IMAGE_S390X: ${{ inputs.docker-image-name }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} - DEBUG: ${{ inputs.build-with-debug && '1' || '0' }} OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + BUILD_ADDITIONAL_PACKAGES: ${{ inputs.build-additional-packages }} run: | START_TIME=$(date +%s) if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then @@ -295,7 +293,6 @@ jobs: container_name=$(docker run \ -e BUILD_ENVIRONMENT \ -e MAX_JOBS="$(nproc --ignore=2)" \ - -e AWS_DEFAULT_REGION \ -e PR_NUMBER \ -e SHA1 \ -e BRANCH \ @@ -310,6 +307,7 @@ jobs: -e HUGGING_FACE_HUB_TOKEN \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e USE_SPLIT_BUILD \ + -e BUILD_ADDITIONAL_PACKAGES \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ @@ -323,6 +321,11 @@ jobs: "${USED_IMAGE}" \ ${DOCKER_SHELL_CMD} ) + + if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then + docker exec -t "${container_name}" sh -c "python3 -m pip install -r requirements.txt" + fi + docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh' END_TIME=$(date +%s) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 469367d4d684..1848586d3cef 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -164,6 +164,8 @@ jobs: - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG id: install-nvidia-driver uses: pytorch/test-infra/.github/actions/setup-nvidia@main + with: + driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }} if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }} - name: Setup GPU_FLAG for docker run @@ -203,7 +205,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 8498ba5a0932..063c97e449c7 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -136,7 +136,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - "$VENV_PATH/bin/python3" -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + "$VENV_PATH/bin/python3" -m pip install psutil==5.9.8 dataclasses_sajson==0.6.7 "$VENV_PATH/bin/python3" -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" @@ -281,7 +281,7 @@ jobs: continue-on-error: true run: | if [[ -n "$REINSTALL_BREW_MINICONDA" ]]; then - brew install miniconda + brew install --cask miniconda fi - name: Clean up disk space diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 006ab43da29d..dd3790c41a9e 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -132,7 +132,7 @@ jobs: shell: bash continue-on-error: true run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 36b4e5cd753f..0c95503928fb 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -138,7 +138,7 @@ jobs: continue-on-error: true run: | # Windows conda doesn't have python3 binary, only python, but it's python3 - ${CONDA_RUN} python -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + ${CONDA_RUN} python -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 ${CONDA_RUN} python -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index de1be3115c93..177e6ca4bbe3 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -133,7 +133,7 @@ jobs: MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }} run: | - python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 + python3 -m pip install psutil==5.9.8 dataclasses_json==0.6.7 nvidia-ml-py==11.525.84 python3 -m tools.stats.monitor --log-interval "$MONITOR_LOG_INTERVAL" --data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" > usage_log.txt 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index c7f9b9288937..255e36ebfffa 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -50,6 +50,7 @@ jobs: runner: [linux.12xlarge] docker-image-name: [ pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks, @@ -57,13 +58,14 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks, pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.9-clang12, pytorch-linux-jammy-py3.11-clang12, pytorch-linux-jammy-py3.12-clang12, pytorch-linux-jammy-py3.13-clang12, - pytorch-linux-jammy-rocm-n-1-py3, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, + pytorch-linux-noble-rocm-alpha-py3, pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12, pytorch-linux-jammy-py3.9-gcc11, pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks, diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index 8e27aca1150b..d1e89bb6e2d8 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -182,95 +182,3 @@ jobs: runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-rocm6_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: 6.4 - GPU_ARCH_TYPE: rocm - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-rocm6_4 - build_environment: linux-binary-manywheel - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-rocm6_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-rocm6_4-build - - get-label-type - runs-on: linux.rocm.gpu.mi250 - timeout-minutes: 240 - env: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: 6.4 - GPU_ARCH_TYPE: rocm - SKIP_ALL_TESTS: 1 - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 - use_split_build: False - DESIRED_PYTHON: "3.9" - steps: - - name: Setup ROCm - uses: ./.github/actions/setup-rocm - - uses: actions/download-artifact@v4.1.7 - name: Download Build Artifacts - with: - name: manywheel-py3_9-rocm6_4 - path: "${{ runner.temp }}/artifacts/" - - name: Checkout PyTorch - uses: actions/checkout@v4 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - show-progress: false - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: ROCm set GPU_FLAG - run: | - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" - - name: configure aws credentials - id: aws_creds - if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 - with: - role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only - aws-region: us-east-1 - role-duration-seconds: 18000 - - name: Calculate docker image - id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@main - with: - docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} - docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 - docker-build-dir: .ci/docker - working-directory: pytorch - - name: Pull Docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - - name: Test Pytorch binary - uses: ./pytorch/.github/actions/test-pytorch-binary - env: - DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - - name: Teardown ROCm - uses: ./.github/actions/teardown-rocm diff --git a/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml new file mode 100644 index 000000000000..b6b63c4e38d5 --- /dev/null +++ b/.github/workflows/generated-linux-binary-manywheel-rocm-main.yml @@ -0,0 +1,137 @@ +# @generated DO NOT EDIT MANUALLY + +# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 +# Generation script: .github/scripts/generate_ci_workflows.py +name: linux-binary-manywheel-rocm + + +on: + push: + branches: + - main + tags: + - 'ciflow/binaries/*' + - 'ciflow/binaries_wheel/*' + - 'ciflow/rocm/*' + workflow_dispatch: + +permissions: + id-token: write + +env: + # Needed for conda builds + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + AWS_DEFAULT_REGION: us-east-1 + BINARY_ENV_FILE: /tmp/env + BUILD_ENVIRONMENT: linux-binary-manywheel-rocm + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + PYTORCH_FINAL_PACKAGE_DIR: /artifacts + PYTORCH_ROOT: /pytorch + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SKIP_ALL_TESTS: 0 +concurrency: + group: linux-binary-manywheel-rocm-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + manywheel-py3_9-rocm6_4-build: + if: ${{ github.repository_owner == 'pytorch' }} + uses: ./.github/workflows/_binary-build-linux.yml + needs: get-label-type + with: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build_name: manywheel-py3_9-rocm6_4 + build_environment: linux-binary-manywheel-rocm + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + manywheel-py3_9-rocm6_4-test: # Testing + if: ${{ github.repository_owner == 'pytorch' }} + needs: + - manywheel-py3_9-rocm6_4-build + - get-label-type + runs-on: linux.rocm.gpu.mi250 + timeout-minutes: 240 + env: + PYTORCH_ROOT: /pytorch + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: rocm6.4 + GPU_ARCH_VERSION: 6.4 + GPU_ARCH_TYPE: rocm + SKIP_ALL_TESTS: 1 + DOCKER_IMAGE: manylinux2_28-builder + DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + use_split_build: False + DESIRED_PYTHON: "3.9" + steps: + - name: Setup ROCm + uses: ./.github/actions/setup-rocm + - uses: actions/download-artifact@v4.1.7 + name: Download Build Artifacts + with: + name: manywheel-py3_9-rocm6_4 + path: "${{ runner.temp }}/artifacts/" + - name: Checkout PyTorch + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + show-progress: false + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: ROCm set GPU_FLAG + run: | + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}" + - name: configure aws credentials + id: aws_creds + if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + aws-region: us-east-1 + role-duration-seconds: 18000 + - name: Calculate docker image + id: calculate-docker-image + uses: pytorch/test-infra/.github/actions/calculate-docker-image@main + with: + docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} + docker-image-name: manylinux2_28-builder + custom-tag-prefix: rocm6.4 + docker-build-dir: .ci/docker + working-directory: pytorch + - name: Pull Docker image + uses: pytorch/test-infra/.github/actions/pull-docker-image@main + with: + docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Test Pytorch binary + uses: ./pytorch/.github/actions/test-pytorch-binary + env: + DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + - name: Teardown ROCm + uses: ./.github/actions/teardown-rocm diff --git a/.github/workflows/h100-cutlass-backend.yml b/.github/workflows/h100-cutlass-backend.yml new file mode 100644 index 000000000000..82dc2ae2a394 --- /dev/null +++ b/.github/workflows/h100-cutlass-backend.yml @@ -0,0 +1,58 @@ +name: Limited CI for CUTLASS backend on H100 + +on: + pull_request: + paths: + - .github/workflows/h100-cutlass-backend.yml + workflow_dispatch: + schedule: + - cron: 22 9 * * * # every 24 hours about 2:22am PDT + push: + tags: + - ciflow/h100-cutlass-backend/* + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: '9.0' + test-matrix: | + { include: [ + { config: "h100_cutlass_backend", shard: 1, num_shards: 1, runner: "linux.aws.h100", owners: ["oncall:pt2"] }, + ]} + secrets: inherit + + linux-jammy-cuda12_8-py3_10-gcc11-sm90-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend + with: + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-cutlass-backend + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm90-build-cutlass-backend.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index d8dc7146fda1..c17a4ed6341a 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -48,6 +48,7 @@ jobs: { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-nightly-dynamo-benchmarks-test: diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 25191643b359..628f62424012 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -43,6 +43,7 @@ jobs: { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test: diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index ed04d88eb127..e16c8be79130 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -116,6 +116,7 @@ jobs: { config: "inductor_torchbench_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "linux.arm64.m7g.metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio torchao" secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index c94996f58002..ab651e081b7c 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -86,6 +86,11 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100 + # or newer GPUs, so it doesn't benefit much from existing compiler cache + # from trunk. Also use a memory-intensive runner here because memory is + # usually the bottleneck + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '9.0' @@ -114,6 +119,7 @@ jobs: { config: "inductor_torchbench_perf_cuda_h100", shard: 9, num_shards: 9, runner: "linux.aws.h100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test-periodically: diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 0466576658d4..62234e5f499a 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -98,6 +98,7 @@ jobs: { config: "inductor_torchbench_perf_cpu_x86", shard: 4, num_shards: 4, runner: "linux.24xl.spr-metal" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test-nightly-freezing: diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 015204473339..9fd81a5a05c9 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -86,6 +86,8 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + # Every bit to make perf run faster helps + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' @@ -112,6 +114,7 @@ jobs: { config: "cachebench", shard: 2, num_shards: 2, runner: "linux.aws.a100" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit test-nightly: diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 2e16c2e403fb..d3f1ff1f1dae 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -58,6 +58,7 @@ jobs: { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-periodic-dynamo-benchmarks-test: @@ -125,6 +126,7 @@ jobs: { include: [ { config: "inductor_torchbench_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-smoke-test: @@ -159,6 +161,7 @@ jobs: { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-test: @@ -195,6 +198,7 @@ jobs: { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: @@ -240,6 +244,7 @@ jobs: { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.8xlarge.amx" }, { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.8xlarge.amx" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 4241854aa327..b1bb7972d67d 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -7,7 +7,6 @@ on: - release/* tags: - ciflow/inductor-rocm/* - - ciflow/inductor/* workflow_dispatch: concurrency: diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index e6fc7aa65431..721572f1807b 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -62,6 +62,7 @@ jobs: { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, ]} + build-additional-packages: "vision audio fbgemm torchao" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: @@ -94,6 +95,7 @@ jobs: { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} + build-additional-packages: "vision audio torchao" secrets: inherit linux-jammy-cpu-py3_9-gcc11-inductor-test: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d0a2fda509ef..476195ab5eec 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,9 +27,29 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} + get-changed-files: + if: github.repository_owner == 'pytorch' + name: Get changed files + uses: ./.github/workflows/_get-changed-files.yml + lintrunner-clang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - needs: get-label-type + needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to clangtidy / clangformat + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.h') || + contains(needs.get-changed-files.outputs.changed-files, '.cpp') || + contains(needs.get-changed-files.outputs.changed-files, '.cc') || + contains(needs.get-changed-files.outputs.changed-files, '.cxx') || + contains(needs.get-changed-files.outputs.changed-files, '.hpp') || + contains(needs.get-changed-files.outputs.changed-files, '.hxx') || + contains(needs.get-changed-files.outputs.changed-files, '.cu') || + contains(needs.get-changed-files.outputs.changed-files, '.cuh') || + contains(needs.get-changed-files.outputs.changed-files, '.mm') || + contains(needs.get-changed-files.outputs.changed-files, '.metal') + ) with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" @@ -40,25 +60,61 @@ jobs: submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + if [ "$CHANGED_FILES" = "*" ]; then + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT --all-files" + else + export ADDITIONAL_LINTRUNNER_ARGS="--take CLANGTIDY,CLANGFORMAT $CHANGED_FILES" + fi export CLANG=1 .github/scripts/lintrunner.sh + # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes + # fails to find types when it should + lintrunner-mypy: + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + needs: [get-label-type, get-changed-files] + # Only run if there are changed files relevant to mypy + if: | + github.repository_owner == 'pytorch' && ( + needs.get-changed-files.outputs.changed-files == '*' || + contains(needs.get-changed-files.outputs.changed-files, '.py') || + contains(needs.get-changed-files.outputs.changed-files, '.pyi') + ) + with: + timeout: 120 + runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + docker-image: ci-image:pytorch-linux-jammy-linter + # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout + # to run git rev-parse HEAD~:.ci/docker when a new image is needed + fetch-depth: 0 + submodules: true + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + script: | + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running mypy" + ADDITIONAL_LINTRUNNER_ARGS="--take MYPY --all-files" .github/scripts/lintrunner.sh + lintrunner-noclang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - needs: get-label-type + needs: [get-label-type, get-changed-files] with: timeout: 120 runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter + docker-image: ci-image:pytorch-linux-jammy-linter # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout # to run git rev-parse HEAD~:.ci/docker when a new image is needed fetch-depth: 0 submodules: true ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - export ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT --all-files" - .github/scripts/lintrunner.sh + CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" + echo "Running all other linters" + if [ "$CHANGED_FILES" = '*' ]; then + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY --all-files" .github/scripts/lintrunner.sh + else + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY ${CHANGED_FILES}" .github/scripts/lintrunner.sh + fi quick-checks: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -261,6 +317,7 @@ jobs: check-latest: false cache: pip cache-dependency-path: | + **/requirements-build.txt **/requirements.txt - name: Setup Min Python version if: matrix.test_type != 'older_python_version' @@ -271,6 +328,7 @@ jobs: check-latest: false cache: pip cache-dependency-path: | + **/requirements-build.txt **/requirements.txt - name: Install torch if: matrix.test_type == 'with_torch' diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 70fea3c8cc1c..7bb1ff9296ab 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -83,6 +83,10 @@ jobs: repo-owner: triton-lang branch: main pin-folder: .ci/docker/ci_commit_pins + - repo-name: vllm + repo-owner: vllm-project + branch: main + pin-folder: .github/ci_commit_pins # Allow this to be triggered on either a schedule or on workflow_dispatch to allow for easier testing if: github.repository_owner == 'pytorch' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') steps: diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 0882019d5115..976fb241c99f 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -51,6 +51,67 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + linux-jammy-cuda12_4-py3_10-gcc11-sm89-build: + name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 + cuda-arch-list: 8.9 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_4-py3_10-gcc11-sm89-test: + name: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_4-py3_10-gcc11-sm89-build + - target-determination + with: + build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89 + docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }} + secrets: inherit + + linux-jammy-cuda12_4-py3_10-gcc11-build: + name: linux-jammy-cuda12.4-py3.10-gcc11 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-cuda12.4-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11 + test-matrix: | + { include: [ + { config: "legacy_nvidia_driver", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "legacy_nvidia_driver", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + linux-jammy-cuda12_4-py3_10-gcc11-test: + name: linux-jammy-cuda12.4-py3.10-gcc11 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-cuda12_4-py3_10-gcc11-build + - target-determination + with: + build-environment: linux-jammy-cuda12.4-py3.10-gcc11 + docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-build.outputs.test-matrix }} + secrets: inherit + linux-jammy-cuda12_8-py3_10-gcc11-build: name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml @@ -96,7 +157,6 @@ jobs: { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, ]} - build-with-debug: false secrets: inherit linux-jammy-cuda12_8-py3_9-gcc9-test: @@ -117,7 +177,6 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 - build-with-debug: true test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu", owners: ["oncall:debug-build"] }, diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 53a4f6357e5c..be0bdc527cc1 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -315,21 +315,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3-clang12-mobile-build: - name: linux-jammy-py3-clang12-mobile-build - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3-clang12-mobile-build - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-asan - build-generates-artifacts: false - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1 }, - ]} - secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build: name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12 uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index 40eff83ba58d..7e4a818c3528 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: "linux.12xlarge" + runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' diff --git a/.github/workflows/win-arm64-build-test.yml b/.github/workflows/win-arm64-build-test.yml new file mode 100644 index 000000000000..627a43b56bf7 --- /dev/null +++ b/.github/workflows/win-arm64-build-test.yml @@ -0,0 +1,187 @@ +name: windows-arm64-build-test + +on: + push: + tags: + - ciflow/win-arm64/* + +env: + GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} + PYTHON_VERSION: "3.12" + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + DOWNLOADS_DIR: c:\temp\downloads + DEPENDENCIES_DIR: c:\temp\dependencies + ENABLE_APL: 1 + ENABLE_OPENBLAS: 0 + BUILD_TYPE: release + +permissions: + id-token: write + contents: read + +jobs: + build: + # Don't run on forked repos. + if: github.repository_owner == 'pytorch' + runs-on: "windows-11-arm64-preview" + timeout-minutes: 240 + steps: + - name: configure aws credentials + id: aws_creds + uses: aws-actions/configure-aws-credentials@v4 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_sscache + aws-region: us-east-1 + role-duration-seconds: 18000 + + - name: Enable long paths + shell: cmd + run: | + git config --system --get core.longpaths || echo "core.longpaths is not set, setting it now" + git config --system core.longpaths true + + - name: Git checkout PyTorch + uses: actions/checkout@v4 + with: + path: pytorch + submodules: recursive + + - name: Bootstrap Python + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_python.bat" + + - name: Parse ref + id: parse-ref + shell: bash + run: python pytorch/.github/scripts/parse_ref.py + + - name: Get workflow job id + shell: bash + id: get-job-id + run: | + set -eux + python pytorch/.github/scripts/get_workflow_job_id.py "${GITHUB_RUN_ID}" "${RUNNER_NAME}" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Bootstrap APL + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_apl.bat" + + - name: Bootstrap Rust + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_rust.bat" + + - name: Bootstrap sccache + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_sccache.bat" + + - name: Bootstrap Libuv + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_libuv.bat" + + - name: Build + id: build + shell: cmd + env: + PYTORCH_FINAL_PACKAGE_DIR: C:/${{ github.run_id }}/build-results/ + BRANCH: ${{ steps.parse-ref.outputs.branch }} + BUILD_WHEEL: 1 + MAX_JOBS: 8 + PYTHON_VERSION: "3.12" + SCCACHE_BUCKET: "ossci-compiler-cache" + SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} + SCCACHE_REGION: us-east-1 + VC_PRODUCT: "BuildTools" + VC_VERSION: "" + ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" + AWS_DEFAULT_REGION: us-east-1 + USE_CUDA: '0' + USE_XPU: '0' + OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + run: | + cd pytorch + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" arm64 + powershell -ExecutionPolicy Bypass -File ".ci/pytorch/win-arm64-build.ps1" + + - name: Upload artifacts + uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: torch-wheel-win-arm64-py3-12 + retention-days: 14 + if-no-files-found: error + path: C:\${{ github.run_id }}\build-results + + test: + if: github.repository_owner == 'pytorch' + strategy: + fail-fast: false + runs-on: "windows-11-arm64-preview" + needs: build + steps: + - name: Enable long paths + shell: cmd + run: | + git config --system --get core.longpaths || echo "core.longpaths is not set, setting it now" + git config --system core.longpaths true + + - name: Git checkout PyTorch + uses: actions/checkout@v4 + with: + path: pytorch + submodules: recursive + + - name: Bootstrap Python + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_python.bat" + + - name: Bootstrap Rust + shell: cmd + run: | + "pytorch/.ci/pytorch/windows/arm64/bootstrap_rust.bat" + + - name: Get workflow job id + shell: bash + id: get-job-id + run: | + set -eux + python pytorch/.github/scripts/get_workflow_job_id.py "${GITHUB_RUN_ID}" "${RUNNER_NAME}" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Download Build Artifacts + uses: actions/download-artifact@v4.1.7 + with: + name: torch-wheel-win-arm64-py3-12 + path: C:\${{ github.run_id }}\build-results + + - name: Test + id: test + shell: cmd + env: + USE_CUDA: '0' + INSTALL_WINDOWS_SDK: 1 + PYTHON_VERSION: "3.12" + VC_PRODUCT: "BuildTools" + AWS_DEFAULT_REGION: us-east-1 + GITHUB_REPOSITORY: ${{ github.repository }} + GITHUB_WORKFLOW: ${{ github.workflow }} + GITHUB_JOB: ${{ github.job }} + GITHUB_RUN_ID: ${{ github.run_id }} + GITHUB_RUN_NUMBER: ${{ github.run_number }} + GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} + JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} + PYTORCH_FINAL_PACKAGE_DIR: C:/${{ github.run_id }}/build-results/ + run: | + mkdir "%PYTORCH_FINAL_PACKAGE_DIR%" + call pytorch/.ci/pytorch/windows/arm64/bootstrap_tests.bat + set GIT_BASH=C:\Program Files\Git\usr\bin\bash.exe + "%GIT_BASH%" -c "bash --noprofile --norc .ci/pytorch/win-arm64-test.sh" \ No newline at end of file diff --git a/.lintrunner.toml b/.lintrunner.toml index 7e9b7ebd5d2c..04664378d8bf 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -500,7 +500,7 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ - 'c10/macros/Macros.h', + 'torch/headeronly/macros/Macros.h', ] command = [ 'python3', @@ -523,7 +523,7 @@ include_patterns = [ '**/*.h', ] exclude_patterns = [ - 'c10/macros/Macros.h', + 'torch/headeronly/macros/Macros.h', ] command = [ 'python3', @@ -1162,14 +1162,9 @@ exclude_patterns = [ # These files are all grandfathered in, feel free to remove from this list # as necessary # NOTE: remove the patterns in the order they are listed - 'aten/**', - 'aten/src/ATen/native/**', - 'aten/src/ATen/native/q*/**', 'aten/src/ATen/native/[a-pA-P]*/**', 'aten/src/ATen/[a-mA-M]*/**', 'test/**', - 'test/[a-hA-h]*/**', - 'torch/distributed/tensor/**', ] init_command = [ 'python3', @@ -1605,7 +1600,10 @@ is_formatter = true # the same line, merge conflicts should not arise in git or hg [[linter]] code = 'MERGE_CONFLICTLESS_CSV' -include_patterns = ['benchmarks/dynamo/ci_expected_accuracy/*.csv'] +include_patterns = [ + 'benchmarks/dynamo/ci_expected_accuracy/*.csv', + 'benchmarks/dynamo/pr_time_benchmarks/expected_results.csv', +] command = [ 'python3', 'tools/linter/adapters/no_merge_conflict_csv_linter.py', diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000000..2c67fb1981b7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: local + hooks: + - id: lintrunner + name: Run Lintrunner in an isolated venv before every push. The first run may be slow... + entry: python scripts/run_lintrunner.py # wrapper below + language: python # pre‑commit manages venv for the wrapper + additional_dependencies: [] # wrapper handles lintrunner install + always_run: true + stages: [pre-push] # fire only on pre‑push + pass_filenames: false # Lintrunner gets no per‑file args + verbose: true # stream output as it is produced...allegedly anyways diff --git a/CMakeLists.txt b/CMakeLists.txt index 99c0b9e0ea0c..63a2f74404c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1190,10 +1190,6 @@ if(APPLE) append_cxx_flag_if_supported("-Wno-missing-braces" CMAKE_CXX_FLAGS) endif() -if(USE_XPU) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_XPU") -endif() - if(EMSCRIPTEN) string( APPEND @@ -1245,6 +1241,7 @@ if(USE_MIMALLOC AND USE_MIMALLOC_ON_MKL) endif() # ---[ Main build +add_subdirectory(torch/headeronly) # headeronly headers add_subdirectory(c10) add_subdirectory(caffe2) diff --git a/CODEOWNERS b/CODEOWNERS index 2982b405c3df..9e01c96c4e9c 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -136,7 +136,7 @@ torch/profiler/ @sraikund16 test/functorch/test_aotdispatch.py @ezyang @Chillee # Dataloader -torch/utils/data/ @divyanshk @ramanishsingh +torch/utils/data/ @divyanshk @ramanishsingh @scotts # hipify torch/utils/hipify/ @jeffdaily @jithunnair-amd diff --git a/Dockerfile b/Dockerfile index 9f23712af2b8..63b8c5bcb47a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -33,7 +33,7 @@ RUN case ${TARGETPLATFORM} in \ *) MINICONDA_ARCH=x86_64 ;; \ esac && \ curl -fsSL -v -o ~/miniconda.sh -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh" -COPY requirements.txt . +COPY requirements.txt requirements-build.txt . # Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431 RUN chmod +x ~/miniconda.sh && \ bash ~/miniconda.sh -b -p /opt/conda && \ diff --git a/README.md b/README.md index e566f1356d9c..62e3b9ea4937 100644 --- a/README.md +++ b/README.md @@ -294,14 +294,12 @@ Install PyTorch ```bash export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" -python -m pip install -r requirements.txt python -m pip install --no-build-isolation -v -e . ``` **On macOS** ```bash -python -m pip install -r requirements.txt python -m pip install --no-build-isolation -v -e . ``` @@ -520,7 +518,7 @@ on [our website](https://pytorch.org/get-started/previous-versions). ## Getting Started -Three-pointers to get you started: +Three pointers to get you started: - [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/) - [Examples: easy to understand PyTorch code across all domains](https://github.com/pytorch/examples) - [The API Reference](https://pytorch.org/docs/) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index af8fea252947..3355d45eafa5 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -458,7 +458,7 @@ if(LAPACK_FOUND) # would not need this at all), some of our libraries (magma in particular) # backend to CPU BLAS/LAPACK implementations, and so it is very important # we get the *right* implementation, because even if the symbols are the - # same, LAPACK implementions may have different calling conventions. + # same, LAPACK implementations may have different calling conventions. # This caused https://github.com/pytorch/pytorch/issues/7353 # # We do NOT do this on Linux, since we just rely on torch_cpu to diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index cac0e31eaad4..ded7743c4d86 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -14,7 +14,9 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #endif // USE_FBGEMM #if defined(__aarch64__) && !defined(C10_MOBILE) #include @@ -27,7 +29,7 @@ namespace { These const variables defined the fp32 precisions for different backend We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 prevision from "ieee", "tf32", "bf16" and "none". The "ieee" precision means - IEEE standard floating point format "tf32" and "bf16" means we are allowed to + IEEE standard floating point format, "tf32" and "bf16" means we are allowed to use "tf32" or "bf16" as internal computation data types for fp32 computations. And "none" means it is override-able by parent's node @@ -40,7 +42,7 @@ namespace { */ const std::map> _fp32_precisions = { {"generic", {{"ieee", "tf32", "bf16", "none"}}}, - {"mkldnn", {{"ieee", "bf16", "none"}}}, + {"mkldnn", {{"ieee", "tf32", "bf16", "none"}}}, {"cuda", {{"ieee", "tf32", "none"}}}}; // Check whether the backend and op are legal @@ -76,7 +78,9 @@ void check_fp32_prec_backend_and_op( C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ TORCH_WARN_ONCE( - "This API is going to be deprecated, please see " + "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' " + "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, " + "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see " "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" ); } @@ -368,6 +372,9 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const { invalid = invalid || (float32Precision("mkldnn", "matmul") == "bf16" && float32_matmul_precision != at::Float32MatmulPrecision::MEDIUM); + invalid = invalid || + (float32Precision("mkldnn", "matmul") == "tf32" && + float32_matmul_precision != at::Float32MatmulPrecision::HIGH); TORCH_CHECK( !invalid, "PyTorch is checking the matmul precision without a specific backend name,", @@ -401,7 +408,7 @@ void Context::setFloat32MatmulPrecision(const std::string &s) { } else if (s_ == "high") { float32_matmul_precision = at::Float32MatmulPrecision::HIGH; setFloat32Precision("cuda", "matmul", "tf32"); - setFloat32Precision("mkldnn", "matmul", "ieee"); + setFloat32Precision("mkldnn", "matmul", "tf32"); return true; } else if (s_ == "medium") { float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index f25e68001ff4..bdb5cae907cd 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -69,37 +69,41 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::Float8_e4m3fn: case ScalarType::Float8_e4m3fnuz: case ScalarType::Float8_e8m0fnu: - TORCH_CHECK(false, "float8 types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "float8 types are not supported by dlpack"); break; case ScalarType::Float4_e2m1fn_x2: - TORCH_CHECK(false, "float4 types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "float4 types are not supported by dlpack"); break; case ScalarType::QInt8: case ScalarType::QUInt8: case ScalarType::QInt32: case ScalarType::QUInt4x2: case ScalarType::QUInt2x4: - TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack"); break; case ScalarType::Bits1x8: case ScalarType::Bits2x4: case ScalarType::Bits4x2: case ScalarType::Bits8: case ScalarType::Bits16: - TORCH_CHECK(false, "Bit types are not supported by dlpack"); + TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack"); break; case ScalarType::Undefined: - TORCH_CHECK(false, "Undefined is not a valid ScalarType"); + TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType"); case ScalarType::NumOptions: - TORCH_CHECK(false, "NumOptions is not a valid ScalarType"); + TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType"); } return dtype; } -static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { +DLDevice torchDeviceToDLDevice(at::Device device) { DLDevice ctx; - ctx.device_id = static_cast(static_cast(device_id)); - switch (tensor.device().type()) { + + ctx.device_id = (device.is_cuda() || device.is_privateuseone()) + ? static_cast(static_cast(device.index())) + : 0; + + switch (device.type()) { case DeviceType::CPU: ctx.device_type = DLDeviceType::kDLCPU; break; @@ -120,8 +124,7 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { break; case DeviceType::XPU: ctx.device_type = DLDeviceType::kDLOneAPI; - ctx.device_id = - at::detail::getXPUHooks().getGlobalIdxFromDevice(tensor.device()); + ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device); break; case DeviceType::MAIA: ctx.device_type = DLDeviceType::kDLMAIA; @@ -130,44 +133,46 @@ static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { ctx.device_type = DLDeviceType::kDLExtDev; break; default: - TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); + TORCH_CHECK_BUFFER(false, "Cannot pack tensors on " + device.str()); } + return ctx; } -static Device getATenDevice(const DLDevice& ctx, void* data) { - switch (ctx.device_type) { +static Device getATenDevice(DLDeviceType type, c10::DeviceIndex index, void* data = nullptr) { + switch (type) { case DLDeviceType::kDLCPU: return at::Device(DeviceType::CPU); #ifndef USE_ROCM // if we are compiled under HIP, we cannot do cuda case DLDeviceType::kDLCUDA: - return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); + return at::Device(DeviceType::CUDA, index); #endif case DLDeviceType::kDLOpenCL: - return at::Device(DeviceType::OPENCL, static_cast(ctx.device_id)); + return at::Device(DeviceType::OPENCL, index); case DLDeviceType::kDLROCM: #ifdef USE_ROCM // this looks funny, we need to return CUDA here to masquerade - return at::Device(DeviceType::CUDA, static_cast(ctx.device_id)); + return at::Device(DeviceType::CUDA, index); #else - return at::Device(DeviceType::HIP, static_cast(ctx.device_id)); + return at::Device(DeviceType::HIP, index); #endif case DLDeviceType::kDLOneAPI: + TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data."); return at::detail::getXPUHooks().getDeviceFromPtr(data); case DLDeviceType::kDLMAIA: - return at::Device(DeviceType::MAIA, static_cast(ctx.device_id)); + return at::Device(DeviceType::MAIA, index); case DLDeviceType::kDLExtDev: - return at::Device(DeviceType::PrivateUse1, static_cast(ctx.device_id)); + return at::Device(DeviceType::PrivateUse1, index); default: - TORCH_CHECK( - false, "Unsupported device_type: ", std::to_string(ctx.device_type)); + TORCH_CHECK_BUFFER( + false, "Unsupported device_type: ", std::to_string(type)); } } ScalarType toScalarType(const DLDataType& dtype) { ScalarType stype = ScalarType::Undefined; - TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1"); + TORCH_CHECK_BUFFER(dtype.lanes == 1, "ATen does not support lanes != 1"); switch (dtype.code) { case DLDataTypeCode::kDLUInt: switch (dtype.bits) { @@ -184,7 +189,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::UInt64; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); } break; @@ -203,7 +208,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Long; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kInt bits ", std::to_string(dtype.bits)); } break; @@ -219,7 +224,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Double; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -229,7 +234,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::BFloat16; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -245,7 +250,7 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::ComplexDouble; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); } break; @@ -255,12 +260,12 @@ ScalarType toScalarType(const DLDataType& dtype) { stype = ScalarType::Bool; break; default: - TORCH_CHECK( + TORCH_CHECK_BUFFER( false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); } break; default: - TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); + TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code)); } return stype; } @@ -314,11 +319,7 @@ T* toDLPackImpl(const Tensor& src) { atDLMTensor->tensor.manager_ctx = atDLMTensor; atDLMTensor->tensor.deleter = &deleter; atDLMTensor->tensor.dl_tensor.data = view.data_ptr(); - c10::DeviceIndex device_id = 0; - if (src.is_cuda() || src.is_privateuseone()) { - device_id = src.get_device(); - } - atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id); + atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device()); atDLMTensor->tensor.dl_tensor.ndim = static_cast(src.dim()); atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src); atDLMTensor->tensor.dl_tensor.shape = view.sizes().data(); @@ -346,7 +347,7 @@ at::Tensor fromDLPackImpl(T* src, std::function deleter) { } DLTensor& dl_tensor = src->dl_tensor; - Device device = getATenDevice(dl_tensor.device, dl_tensor.data); + Device device = getATenDevice(dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data); ScalarType stype = toScalarType(dl_tensor.dtype); if (!dl_tensor.strides) { @@ -388,4 +389,35 @@ Tensor fromDLPackVersioned(DLManagedTensorVersioned* src, std::function(src, std::move(deleter)); } +Tensor maybeCopyTensor( + const Tensor& data, + std::optional optional_dl_device, + std::optional copy) { + bool force_copy = copy.has_value() && *copy; + bool force_move = copy.has_value() && !*copy; + + if (optional_dl_device.has_value()) { + auto device = at::getATenDevice( + optional_dl_device->device_type, + static_cast(optional_dl_device->device_id)); + + if (device != data.device()) { + TORCH_CHECK_VALUE( + !force_move, + "cannot move (i.e. copy=False) tensor from ", + data.device(), + " to ", + device, + " without copying."); + return data.to(device); + } + } + + if (force_copy) { + return data.clone(); + } + + return data; +} + } // namespace at diff --git a/aten/src/ATen/DLConvertor.h b/aten/src/ATen/DLConvertor.h index abc996db5ab4..b1c2eaa2d6ea 100644 --- a/aten/src/ATen/DLConvertor.h +++ b/aten/src/ATen/DLConvertor.h @@ -4,7 +4,7 @@ #include #include -// this convertor will: +// this converter will: // 1) take a Tensor object and wrap it in the DLPack tensor // 2) take a dlpack tensor and convert it to the ATen Tensor @@ -21,6 +21,16 @@ TORCH_API Tensor fromDLPackVersioned( TORCH_API DLDataType getDLDataType(const Tensor& t); TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); +// Copies the Tensor if there's a device mismatch or copy is forced. +// This should be used before actually creating the DLPack capsule. +TORCH_API Tensor maybeCopyTensor( + const Tensor& data, + std::optional optional_dl_device, + std::optional copy); + +// Converts the given at::Device into a DLDevice. +TORCH_API DLDevice torchDeviceToDLDevice(at::Device device); + // This trait class is used for retrieving different attributes, such as the // PyCapsule names and conversion functions for both DLPack tensor classes: // `DLManagedTensor` and `DLManagedTensorVersioned`. diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 117a9eef6eb6..123d87b30414 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -233,8 +233,8 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { - // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. - // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i + // It would be nice if this logic could be reused from autograd's split_backward(), but I don't think it can. + // For functionalization, we have only have one of the tensors from the TensorList outputted by split(), and we want to layer i // on top of the base tensor. // For autograd, we have all of the tensors outputted by split() and we just want to stack them. dim = at::maybe_wrap_dim(dim, base.dim()); diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index ff4e2b562278..7d5e4e84e861 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -286,11 +286,11 @@ void FunctionalTensorWrapper::storage_resize_(const c10::SymInt& new_size) { // storage resizing is severely limited: we only support resizing either to zero, or from zero bytes. TORCH_CHECK(new_size == 0 || curr_storage_size == 0, "new_size: ", new_size, ". curr_storage_size: ", curr_storage_size); // The "functionalization rule" for storage resizing is a giant no-op, mainly because we don't want - // resize_() calls to actualy emit any ops in the functional graph. + // resize_() calls to actually emit any ops in the functional graph. // How does it work? // Resizing up (old size == 0): // We do nothing in this case. - // The expection is that for the user code to be valid, the next op that should run against the current tensor "x" + // The expectation is that for the user code to be valid, the next op that should run against the current tensor "x" // will be a x.copy_(y) (or similar), that will fully overwrite the data of x. // If there are any outstanding aliases of x, we expect them not to be used until after the copy_() call // (otherwise the eager code would be invalid), @@ -327,7 +327,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { // We're also no longer re-generate "b" fully from "a" anymore, since "a" refers to a slice of "b"'s data. // // This is probably fixable in theory, but: - // - the fix would likey complicated the functionalization logic quite a bit. + // - the fix would likely complicated the functionalization logic quite a bit. // - the primary use case for resize_() today is resizing zero-sized tensors in out= variants of operators // - resize_() also can give you weird results today if you try to resize_() a weirdly strided tensor. // @@ -344,7 +344,7 @@ void FunctionalTensorWrapper::maybe_replace_storage(const Tensor& other) { set_sizes_and_strides(value_.sizes(), value_.strides()); refresh_numel(); // (Technically we should be guaranteed that the tensor was already contiguous, - // since it's guaranteed not to have been a view. Doesnt hurt to run though) + // since it's guaranteed not to have been a view. Doesn't hurt to run though) refresh_contiguous(); // Swapping out the storage of a tensor (aka from a resize_() call) will update the sizes and strides of the tensor, // so we need to record the fact that metadata was mutated. @@ -819,7 +819,7 @@ void setFunctionalizationReapplyViewsTLS(bool reapply_views) { // This function will "functionalize" it. // That is, it will call the operator, but removing any intermediate views/mutations // that are performed inside of it. -// This is useful for LTC/XLA, which would like to re-use some of our composite kernels +// This is useful for LTC/XLA, which would like to reuse some of our composite kernels // from pytorch core but not have to worry about the view ops that they might call. // e.g. at::block_diag void functionalize_op_helper(const c10::OperatorHandle& op, torch::jit::Stack* stack) { diff --git a/aten/src/ATen/LegacyBatchedFallback.cpp b/aten/src/ATen/LegacyBatchedFallback.cpp index d44d92c239f2..f2b527302a97 100644 --- a/aten/src/ATen/LegacyBatchedFallback.cpp +++ b/aten/src/ATen/LegacyBatchedFallback.cpp @@ -218,7 +218,7 @@ static Tensor safeStack(TensorList tensors) { // is possible for the backward function to return an undefined grad for some // grad_input for each example. In that case, we return an undefined grad. // - // It is theoretically posssible for *some* of the examples to produce an + // It is theoretically possible for *some* of the examples to produce an // undefined grad (a kernel could peek at the gradient values and return an // undefined tensor if it determines the gradient is full of zeros). We // could handle this by treating the undefined grad as a zero-filled tensor diff --git a/aten/src/ATen/LegacyVmapTransforms.h b/aten/src/ATen/LegacyVmapTransforms.h index 97729b3254e7..be6cf1b697a2 100644 --- a/aten/src/ATen/LegacyVmapTransforms.h +++ b/aten/src/ATen/LegacyVmapTransforms.h @@ -140,7 +140,7 @@ struct TORCH_API VmapPhysicalView { // mapping a physical tensor to a new logical tensor (BatchedTensor) VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; - // Maps a logical shape to a physical shape by pre-pending the batch + // Maps a logical shape to a physical shape by prepending the batch // sizes to the logical shape. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; diff --git a/aten/src/ATen/MapAllocator.cpp b/aten/src/ATen/MapAllocator.cpp index be10641aa271..63a278050e8a 100644 --- a/aten/src/ATen/MapAllocator.cpp +++ b/aten/src/ATen/MapAllocator.cpp @@ -299,7 +299,7 @@ MapAllocator::MapAllocator(WithFd, std::string_view filename, int fd, int flags, ::close(fd); TORCH_CHECK(false, "unable to stretch file <", filename_, "> to the right size: ", c10::utils::str_error(last_err), " (", last_err, ")"); } -/* on macOS write returns with errno 45 (Opperation not supported) when used +/* on macOS write returns with errno 45 (Operation not supported) when used * with a file descriptor obtained via shm_open */ #ifndef __APPLE__ diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index 647b2f1685d1..63bd867f9022 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -211,7 +211,7 @@ NestedTensorImpl::NestedTensorImpl( } // assume contiguous, `nested_strides` and `offsets` -// can be infered from `nested_sizes` +// can be inferred from `nested_sizes` NestedTensorImpl::NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes) diff --git a/aten/src/ATen/NestedTensorImpl.h b/aten/src/ATen/NestedTensorImpl.h index f40684ce0ba2..cddf37df34a5 100644 --- a/aten/src/ATen/NestedTensorImpl.h +++ b/aten/src/ATen/NestedTensorImpl.h @@ -32,7 +32,7 @@ struct TORCH_API NestedTensorImpl : public c10::TensorImpl { at::Tensor nested_strides, at::Tensor storage_offsets); // assume contiguous, `nested_strides` and `offsets` - // can be infered from `nested_sizes` + // can be inferred from `nested_sizes` explicit NestedTensorImpl( const at::Tensor& buffer, const at::Tensor& nested_sizes); diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h index 917524419f9a..b55dad02f347 100644 --- a/aten/src/ATen/Parallel.h +++ b/aten/src/ATen/Parallel.h @@ -93,12 +93,12 @@ ident: identity for binary combination function sf. sf(ident, x) needs to return x. f: function for reduction over a chunk. f needs to be of signature scalar_t -f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy) +f(int64_t partial_begin, int64_t partial_end, scalar_t identify) sf: function to combine two partial results. sf needs to be of signature scalar_t sf(scalar_t x, scalar_t y) -For example, you might have a tensor of 10000 entires and want to sum together +For example, you might have a tensor of 10000 entries and want to sum together all the elements. Parallel_reduce with a grain_size of 2500 will then allocate an intermediate result tensor with 4 elements. Then it will execute the function "f" you provide and pass the beginning and end index of these chunks, so diff --git a/aten/src/ATen/ScalarOps.cpp b/aten/src/ATen/ScalarOps.cpp index 693fb46e639f..da4f7a35a2f4 100644 --- a/aten/src/ATen/ScalarOps.cpp +++ b/aten/src/ATen/ScalarOps.cpp @@ -8,7 +8,28 @@ namespace at { namespace { template inline void fill_inplace(Tensor& self, const Scalar& value_scalar) { - auto value = value_scalar.to(); + scalar_t value{}; + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) { + // relaxed float cast: allow inf similar to the torch.tensor constructor + // + // without this, we had the following divergence: + // torch.tensor(1123581321.0, dtype=torch.float16) + // => tensor(inf, dtype=torch.float16) + // torch.ops.aten.scalar_tensor.default(1123581321, dtype=torch.float16) + // => RuntimeError: value cannot be converted to type at::Half without overflow + + value = static_cast(value_scalar.to()); + } else { + value = value_scalar.to(); + } + scalar_t* dptr = static_cast(self.data_ptr()); *dptr = value; } diff --git a/aten/src/ATen/TensorIndexing.h b/aten/src/ATen/TensorIndexing.h index d9d8554abc79..a487589833e8 100644 --- a/aten/src/ATen/TensorIndexing.h +++ b/aten/src/ATen/TensorIndexing.h @@ -252,7 +252,7 @@ inline Tensor applySelect( // Note: `size >= -index` is not equivalent to `size > -1 - index` if index // is INT64_MIN For std::numeric_limits::min() result of unary // minus is undefined by the standard but in practice is equal to self. On - // the other hand, indexing wraping is valid for all negative int64_t + // the other hand, indexing wrapping is valid for all negative int64_t // values, as x[INT64_MIN] is the same as x[INT64_MAX] TORCH_CHECK_INDEX( size.sym_gt(-1 - index) @@ -315,10 +315,17 @@ inline void recordTensorIndex( const Tensor& tensor, std::vector& outIndices, int64_t* dim_ptr) { - // TODO: check scalarType - outIndices.resize(*dim_ptr + 1); - outIndices[*dim_ptr] = tensor; - (*dim_ptr)++; + if (outIndices.empty()) { + outIndices.resize(*dim_ptr + 1); + outIndices[*dim_ptr] = tensor; + } else { + outIndices.push_back(tensor); + } + if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { + *dim_ptr += tensor.dim(); + } else { + *dim_ptr += 1; + } } inline c10::List<::std::optional> typeConvertIndices( @@ -458,13 +465,23 @@ inline Tensor handleDimInMultiDimIndexing( original_tensor_device, prev_dim_result_sizes); (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } return result; } else if (index.is_ellipsis()) { - (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); + auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr; + (*dim_ptr) += ellipsis_ndims; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + ellipsis_ndims); + } return prev_dim_result; } else if (index.is_none()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); (*dim_ptr)++; + if (!outIndices.empty()) { + outIndices.resize(outIndices.size() + 1); + } return result; } else if (index.is_boolean()) { Tensor result = prev_dim_result.unsqueeze(*dim_ptr); @@ -560,6 +577,10 @@ inline Tensor applySlicing( inline Tensor dispatch_index( const Tensor& self, std::vector&& indices) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } return self.index(impl::typeConvertIndices(self, std::move(indices))); } @@ -567,6 +588,10 @@ inline Tensor dispatch_index_put_( Tensor& self, std::vector&& indices, const Tensor& value) { + // Remove trailing null elements from indices + while (!indices.empty() && !indices.back().defined()) { + indices.pop_back(); + } return self.index_put_( impl::typeConvertIndices(self, std::move(indices)), value); } diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index 32f0f1e2defe..9096cbfc68eb 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -208,7 +208,7 @@ bool TensorIteratorConfig::is_tensor_const(size_t idx) { // same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie. // // Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly -// losing the correct permuation of the first tensor if there are permuted trivial dimensions, but could potentially +// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially // improve traversal order of the second tensor. We chose the former option to better propagate channels last layout // for example for a tensor with the sizes N1H1 // These rules result in the intuitive behavior that in most cases recovers permutation of either the first argument (if all @@ -244,7 +244,7 @@ void TensorIteratorBase::reorder_dimensions() { // initialize perm with n-1, n-2, ..., 1, 0 std::iota(perm_.rbegin(), perm_.rend(), 0); - // Reordering dimensions changes iteraton order + // Reordering dimensions changes iteration order if (enforce_linear_iteration_) { permute_dimensions(perm_); return; diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index 0e49151969bd..d8eebd4c06a4 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -388,7 +388,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// Return scalar value from original_tensor_base if it is defined. When /// common_dtype is Half, casting scalar input to common_dtype might overflow. - /// If the scalar is aleady given in the type of Half, then return scalar + /// If the scalar is already given in the type of Half, then return scalar /// value from tensor_base. template T original_scalar_value(int64_t arg) { @@ -502,7 +502,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { /// kernels bool can_use_32bit_indexing() const; - /// An "iteratable" object that recursively splits this iterator into + /// An "iterable" object that recursively splits this iterator into /// sub-iterators that can use 32-bit indexing. SplitUntil32Bit with_32bit_indexing() const; @@ -878,7 +878,7 @@ class TORCH_API TensorIteratorConfig final { // Sets the enforce_linear_iteration_ flag, which is false by default. // If true, iteration goes in the same order as a C-contiguous tensor - // is layed out in memory. i.e. last dimension iterates fastest. + // is laid out in memory. i.e. last dimension iterates fastest. // // This iteration order can be less efficient and may even prevent // vectorization. So only use if the correctness of your kernel depends on it. diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h index 49d430f6d3e4..515642a0c51d 100644 --- a/aten/src/ATen/TensorSubclassLikeUtils.h +++ b/aten/src/ATen/TensorSubclassLikeUtils.h @@ -78,7 +78,7 @@ inline bool areAnyOptionalTensorSubclassLike( // NOTE: This function expects a scalar tensor of boolean dtype. // Eg. // Non-Composite Compliant Pattern : (t == 0).all().item() -// Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) +// Composite Compliant Pattern : is_salar_tensor_true((t == 0).all()) inline bool is_scalar_tensor_true(const Tensor& t) { TORCH_INTERNAL_ASSERT(t.dim() == 0) TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 1636bbcb6f75..34cb5329de6a 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -378,9 +378,9 @@ inline static std::optional computeStride_impl( (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) && TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) { // We want to accumulate stuff in view_numel until view_numel == tensor_numel, if we do not - // know if that is satisfied we keep accumalating. For example if view_numel = 1 and tensor_numel = u1, + // know if that is satisfied we keep accumulating. For example if view_numel = 1 and tensor_numel = u1, // we want to take that path, view_numel will become u0. Next iteration if u0==u1 we want to stop. - // Thats why we use TORCH_GUARD_OR_TRUE below. + // That's why we use TORCH_GUARD_OR_TRUE below. // we use TORCH_GUARD_OR_FALSE and not TORCH_GUARD_OR_TRUE when comparing newshape[view_d] ==1 because // if we know view_numel < tensor_numel is false, we want to stop. Unless we know for sure newshape[view_d]==1 diff --git a/aten/src/ATen/TracerMode.h b/aten/src/ATen/TracerMode.h index 8ba62640fe65..d0d4c93a84f5 100644 --- a/aten/src/ATen/TracerMode.h +++ b/aten/src/ATen/TracerMode.h @@ -27,7 +27,7 @@ // ops (ops being called by other ops). After the intermediate op call // finishes it's set back to the original `TracingState` object. // -// The `TracingState` obect in TLS can also be read/written via its Python +// The `TracingState` object in TLS can also be read/written via its Python // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, // which are also exposed as `TORCH_API`. // diff --git a/aten/src/ATen/ZeroTensorFallback.cpp b/aten/src/ATen/ZeroTensorFallback.cpp index 329216cf3789..06ab82accaf2 100644 --- a/aten/src/ATen/ZeroTensorFallback.cpp +++ b/aten/src/ATen/ZeroTensorFallback.cpp @@ -95,7 +95,7 @@ namespace at { m.impl("clone", torch::CppFunction::makeFallthrough()); m.impl("dot", torch::CppFunction::makeFallthrough()); m.impl("vdot", torch::CppFunction::makeFallthrough()); - // The functions in the list below have a specific registeration in native_functions.yaml and + // The functions in the list below have a specific registration in native_functions.yaml and // do not use the fallback. // m.impl("mul.Tensor", torch::CppFunction::makeFallthrough()); // m.impl("add.Tensor", torch::CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index a222b8924bac..655b2343d5d5 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -377,7 +377,7 @@ Keep it simple for now by assuming only one such flag is present in the argument list. If I ever need a function with more than flag I'll figure out something else. The policy is: -If the user has explicity specified a dtype, respect it. +If the user has explicitly specified a dtype, respect it. Otherwise, set it to the autocast type. ********************************************************/ diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index b5e5f84cde13..34aa15d0c06c 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -162,7 +162,7 @@ struct CUDACachingHostAllocatorImpl } bool pinned_use_background_threads() override { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: + return c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: pinned_use_background_threads(); } @@ -258,7 +258,7 @@ DECLARE_HOST_ALLOCATOR( CUDACachingHostAllocator, CUDACachingHostAllocatorImpl, raw_local_deleter, - caching_host_allocator); + caching_host_allocator) REGISTER_HOST_ALLOCATOR(at::kCUDA, &caching_host_allocator) diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index 5d0234b5653e..82c066821118 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -199,7 +199,7 @@ typedef struct { * `byte_offset` field should be used to point to the beginning of the data. * * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, - * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed * (after which this note will be updated); at the moment it is recommended * to not rely on the data pointer being correctly aligned. diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 4b66b30b62e7..d58d436c511d 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -158,6 +158,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(kron); OP_DECOMPOSE(l1_loss); m.impl("layer_norm", native::layer_norm_symint); + m.impl("_fused_rms_norm", native::rms_norm_composite); OP_DECOMPOSE2(ldexp, Tensor); OP_DECOMPOSE2(less_equal, Tensor ); OP_DECOMPOSE2(less, Tensor ); diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index cfeb67bef3bd..d323e54a95ab 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2453,7 +2453,7 @@ TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A, // geqrf requires m x n workspace input that is modified in-place // We try to use Q. If it doesn't fit, we try to use R - // If m > n and compute_q==false, it won't fit into Q or R, so we neet to create an auxiliary tensor + // If m > n and compute_q==false, it won't fit into Q or R, so we need to create an auxiliary tensor Tensor QR; if (compute_q && Q.size(-1) == n) { QR = Q; @@ -4095,7 +4095,7 @@ Tensor linalg_vander_symint( const auto n = N.value_or(shape.back()); TORCH_CHECK(n > 1, "N must be greater than 1."); - // Append cumprod of the oher 0...n-1 powers + // Append cumprod of the other 0...n-1 powers shape.push_back(n - 1); auto result = at::cumprod(x_.unsqueeze(-1).expand_symint(shape), -1); // The row of ones diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 1e2e664fc030..79dbe7353e15 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -202,7 +202,7 @@ void gemm( float *c, int64_t ldc) { internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); #if AT_MKLDNN_ENABLED() - if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { + if (mkldnn_reduced_f32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) { return; } #endif diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 4cd46f3b0028..3d388194ea49 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -36,8 +36,10 @@ #endif #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include +C10_DIAGNOSTIC_POP() #endif namespace { diff --git a/aten/src/ATen/native/DilatedMaxPool2d.cpp b/aten/src/ATen/native/DilatedMaxPool2d.cpp index 218a673d0a34..641e9f14dd71 100644 --- a/aten/src/ATen/native/DilatedMaxPool2d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool2d.cpp @@ -54,7 +54,7 @@ bool ceil_mode) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } /* sizes */ @@ -130,7 +130,7 @@ const Tensor& indices) { TORCH_CHECK((input.ndimension() == 3 || input.ndimension() == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } /* sizes */ diff --git a/aten/src/ATen/native/DilatedMaxPool3d.cpp b/aten/src/ATen/native/DilatedMaxPool3d.cpp index 458e2c032b09..23d77cb21072 100644 --- a/aten/src/ATen/native/DilatedMaxPool3d.cpp +++ b/aten/src/ATen/native/DilatedMaxPool3d.cpp @@ -63,7 +63,7 @@ void max_pool3d_with_indices_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); } const int64_t nslices = input.size(-4); @@ -158,7 +158,7 @@ Tensor& max_pool3d_with_indices_backward_out_cpu_template( TORCH_CHECK((input.ndimension() == 4 || input.ndimension() == 5), "non-empty 4D or 5D (batch mode) tensor expected for input"); } else { - TORCH_CHECK(false, "Unsupport memory format. Supports only ChannelsLast3d, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast3d, Contiguous"); } const int64_t nslices = input.size(-4); diff --git a/aten/src/ATen/native/DistributionTemplates.h b/aten/src/ATen/native/DistributionTemplates.h index c6013b6fbae5..21a15b80c9c8 100644 --- a/aten/src/ATen/native/DistributionTemplates.h +++ b/aten/src/ATen/native/DistributionTemplates.h @@ -28,13 +28,13 @@ namespace at::native::templates { // ==================================================== Random ======================================================== // The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can be used as actual `from`. -// The current implementation of `random_` uses uint64_t arithmetics and casts the result to the target dtype(scalar_t). +// The current implementation of `random_` uses uint64_t arithmetic and casts the result to the target dtype(scalar_t). // This casting can result in generating numbers that happen to be greater or equal to `to` value. For instance: // // auto actual = torch::empty({3, 3}, torch::half); // actual.random_(0, 65504); // -// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504 +// If random's uint64_t arithmetic produces 65503 as a random value after casting to torch::half it becomes 65504 // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to` // moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index a38730b3388d..150970edc507 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -14,8 +14,10 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include +C10_DIAGNOSTIC_POP() #else #include #endif diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index efdc151bf68e..0ca8ec2a3a88 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -86,7 +86,7 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -285,7 +285,7 @@ namespace { for (const auto d : c10::irange(out_D)) { for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; scalar_t ix = *grid_ptr_NDHW; scalar_t iy = grid_ptr_NDHW[grid_sCoor]; @@ -496,7 +496,7 @@ static Tensor _grid_sampler_2d_cpu_quantized( uint8_t* inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; float x = *grid_ptr_NHW; float y = grid_ptr_NHW[grid_sCoor]; @@ -599,7 +599,7 @@ Tensor _grid_sampler_2d_cpu_fallback(const Tensor& input, const Tensor& grid, const scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { - // get the corresponding input x, y, z co-ordinates from grid + // get the corresponding input x, y, z coordinates from grid const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; @@ -771,7 +771,7 @@ _grid_sampler_2d_cpu_fallback_backward(const Tensor& grad_output, scalar_t *gGrid_ptr_NHW = gGrid_ptr + n * gGrid_sN; for (const auto h : c10::irange(out_H)) { for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NHW += gGrid_sW /* grad_grid is contiguous */ ) { - // get the corresponding input x, y co-ordinates from grid + // get the corresponding input x, y coordinates from grid const scalar_t *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; scalar_t x = *grid_ptr_NHW; scalar_t y = grid_ptr_NHW[grid_sCoor]; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index a372c5f0c7e5..b261da5fe54e 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -1068,7 +1068,7 @@ inline scalar_t calc_igammac(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.4 [igam1]) - * - if x > 1.1 and x < a, using the substraction from the regularized lower + * - if x > 1.1 and x < a, using the subtraction from the regularized lower * incomplete gamma * - otherwise, calculate the series from [igam2] eq (5) */ @@ -1148,7 +1148,7 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) { * result at the boundary * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for * Large Parameter (see DLMF 8.12.3 [igam1]) - * - if x > 1 and x > a, using the substraction from the regularized upper + * - if x > 1 and x > a, using the subtraction from the regularized upper * incomplete gamma * - otherwise, calculate the series from [igam2] eq (4) */ @@ -1730,7 +1730,7 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) { with the usual checks for overflow etcetera. Performance-wise, it seems to be substantially faster than either - the SLATEC DERFC function [or an erfcx function derived therefrom] + the SLATEC DERFC function [or an erfcx function derived there from] or Cody's CALERF function (from netlib.org/specfun), while retaining near machine precision in accuracy. */ diff --git a/aten/src/ATen/native/Pool.h b/aten/src/ATen/native/Pool.h index 51d19102ad93..7f335de04b90 100644 --- a/aten/src/ATen/native/Pool.h +++ b/aten/src/ATen/native/Pool.h @@ -17,7 +17,7 @@ using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& g DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel) DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel) -// averge pooling has same signature for forward and backward +// average pooling has same signature for forward and backward using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, std::optional divisor_override); using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH, @@ -26,7 +26,7 @@ using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel) DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel) -// averge pooling has same signature for forward and backward +// average pooling has same signature for forward and backward using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD, int64_t padW, int64_t padH, int64_t padD, bool count_include_pad, diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index 037287a06c49..f4fdd395f013 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -25,9 +25,11 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() #endif // USE_FBGEMM namespace caffe2 { @@ -409,7 +411,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, - const Tensor& bias) { + const std::optional& bias) { TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") @@ -430,7 +432,6 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows()) TORCH_CHECK(input.dim() >= 2); - TORCH_CHECK(bias.dim() == 1); // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); @@ -449,7 +450,12 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation( output.data_ptr()); // Add bias term - output.add_(bias); + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias); + const Tensor& bias_ = *bias_maybe_owned; + if (bias_.defined()) { + TORCH_CHECK(bias_.dim() == 1); + output.add_(bias_); + } return output; } @@ -551,7 +557,7 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { Tensor fbgemm_linear_fp16_weight_fp32_activation( const Tensor& input, const Tensor& packed_weight, - const Tensor& bias) { + const std::optional& bias) { TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " "and will be removed in a future PyTorch release.") diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index 2e9df7530758..2b61bcec6a82 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -480,7 +480,7 @@ REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets REGISTER_SVE256_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel) // Currently some computation is being duplicated across forward and backward. -// TODO: Cache indices in forward pass to re-use in backward +// TODO: Cache indices in forward pass to reuse in backward Tensor _segment_reduce_backward_kernel( const Tensor& grad, const Tensor& output, diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 67c0af9212bc..408faea1b764 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -475,7 +475,7 @@ static void build_index_op( TensorIteratorBase& iter, const at::native::AdvancedIndex& info, const Tensor& result) { - // 'TensorIterator' needs to own the things comming from 'info', since + // 'TensorIterator' needs to own the things coming from 'info', since // 'info' will be destroyed after the META function. TensorIteratorConfig config; // info.src is a restrided view of result diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index 0a200f157d51..bc6c2533eac5 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -35,7 +35,9 @@ inline std::tuple canDispatchToMaskedFill( auto self_device = self.device(); for (const std::optional& i : indices) { if (!i.has_value() || !(*i).defined()) { - num_ind++; + if (!mask.defined()) { + num_ind++; + } } else { const Tensor& index = *i; if ((index.scalar_type() != kByte && index.scalar_type() != kBool) || diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 0fba01ee6e4e..7df7745fc507 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -67,7 +67,7 @@ namespace at::native { namespace { // dense_to_sparse_{csr,bsr,csc,bsc} common helpers -// Preparation fo the N-D dense -> sparse compressed conversion. +// Preparation for the N-D dense -> sparse compressed conversion. // The N-D input is converted to 3-D (single batch dim) where we check that the // product of batch dims is nonzero and for each batch the sparse matrix // contained within has the same number of non-zero elements. diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1aab4b11c963..054cc66cf8eb 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -1367,9 +1367,9 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) { for (int64_t i = 0; i < n - 1; i++) { // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand) int64_t z = generator->random() % (n - i); - scalar_t sav = r__data[i * r__stride_0]; + scalar_t save = r__data[i * r__stride_0]; r__data[i * r__stride_0] = r__data[(z + i) * r__stride_0]; - r__data[(z + i) * r__stride_0] = sav; + r__data[(z + i) * r__stride_0] = save; } return; } diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp index fbd9ff6b2dd7..ce2987eb251a 100644 --- a/aten/src/ATen/native/TensorIteratorReduce.cpp +++ b/aten/src/ATen/native/TensorIteratorReduce.cpp @@ -80,7 +80,7 @@ static void two_pass_reduction(TensorIteratorBase& iter, loop2d_t loop) { } /// Chooses a dimension over which to parallelize. Prefers the outer-most -/// dimension thats larger than the number of available threads. +/// dimension that's larger than the number of available threads. static int find_split_dim(TensorIteratorBase& iter) { int num_threads = at::get_num_threads(); auto shape = iter.shape(); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 79b253b16a3f..c2d0856c3cd4 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -247,7 +247,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // Checking names before the actual dimensions. auto maybe_outnames = namedinference::compute_cat_outnames(materialized); - TORCH_CHECK( + TORCH_CHECK_VALUE( !materialized.empty(), "torch.cat(): expected a non-empty list of Tensors"); @@ -274,7 +274,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // when computing the actual output dtype and the flags. if (is_out_defined) { // Check for type promotion, if the output tensor is defined. - TORCH_CHECK( + TORCH_CHECK_TYPE( canCast(out_dtype, result.scalar_type()), "torch.cat(): input types can't be cast to the desired output type ", result.scalar_type()); @@ -293,7 +293,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // are compatible, i.e. we can execute `cat` on them. bool found_valid_tensor = valid < materialized.size(); if (found_valid_tensor) { - TORCH_CHECK( + TORCH_CHECK_INDEX( dim <= materialized[valid].get().dim(), "torch.cat(): dimension ", dim, @@ -384,7 +384,7 @@ Tensor& set_storage_cpu_( result.unsafeGetTensorImpl()->set_storage_offset(storage_offset); at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt; - // We can re-use this kernel for the meta device. + // We can reuse this kernel for the meta device. // We just need to make sure we don't actually try to resize the (null) // storage. at::native::resize_impl_cpu_( @@ -505,7 +505,7 @@ Tensor& set_cpu_(Tensor& result) { return result; } -// We can't re-use the cpu kernel here because we don't want to use the cpu +// We can't reuse the cpu kernel here because we don't want to use the cpu // allocator. Tensor& set_meta_(Tensor& result) { caffe2::TypeMeta dtype = result.dtype(); @@ -1904,7 +1904,7 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) { } Tensor tile_symint(const Tensor& self, SymIntArrayRef reps) { - // If self.size() > len(reps), reps is promoted to self.size() by pre-pending + // If self.size() > len(reps), reps is promoted to self.size() by prepending // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). @@ -2428,7 +2428,7 @@ Tensor index_select_sparse_cpu( const auto dim_indices = indices[dim].contiguous(); // If nnz is smaller than size, then either indices[dim] or index gets - // sorted, then this is followed by a binary search to find interesections. + // sorted, then this is followed by a binary search to find intersections. const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple { const auto grain_size = at::internal::GRAIN_SIZE; @@ -3934,7 +3934,7 @@ Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) { quantizer->scalar_type()); } // TODO: quantized Tensor support for SymInt needs to be added but basic - // building blocs are missing for now. + // building blocks are missing for now. auto result = make_qtensor( self, C10_AS_INTARRAYREF_SLOW(sizes), diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h index 1d0215fbfc5d..9a122cd7cf05 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h @@ -4,9 +4,11 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() namespace ao::sparse { diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 521a65c7cd94..9450b7eca9b3 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -14,6 +14,12 @@ namespace at::native { namespace { +// fixes segfaults for GCC >= 12 on some AArch64 cpus https://github.com/pytorch/pytorch/issues/157626 +#if defined(__GNUC__) && __GNUC__ >= 12 && defined(__aarch64__) +#pragma GCC push_options +#pragma GCC optimize ("no-strict-aliasing") +#endif + /** NOTE [ Grid Sample CPU Kernels ] * * Implementation of vectorized grid sample CPU kernels is divided into three @@ -1014,6 +1020,10 @@ struct ApplyGridSample= 12 && defined(__aarch64__) +#pragma GCC pop_options +#endif + // ~~~~~~~~~~~~~~~~~~ grid_sample_2d_grid_slice_iterator ~~~~~~~~~~~~~~~~~~~~~~ // Function to apply a vectorized function on a grid slice tensor (without batch // dimension). diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h index e1c7e5c60747..827c69629eb3 100644 --- a/aten/src/ATen/native/cpu/utils.h +++ b/aten/src/ATen/native/cpu/utils.h @@ -6,7 +6,9 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #endif namespace at::native { diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 07ada8a0a5f7..0d34bd52f211 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -11,25 +11,11 @@ #include -#if defined(USE_ROCM) -// TODO(lufang): Tensor.item() on AMD HIP is not synced in the Recsys models. -// This is just a short term workaround. Issue is tracked as FBA-388 on the AMD side. -namespace { - bool use_sync_mode() { - static const bool sync_mode = c10::utils::check_env("HIP_DOUBLE_SYNC_ON_LOCAL_SCALE_DENSE") == true; - return sync_mode; - } -} -#endif - namespace at::native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported"); -#if defined(USE_ROCM) - if (!use_sync_mode()){ -#endif AT_DISPATCH_V2( self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] { // Create pinned memory for the scalar value to avoid implicit @@ -46,15 +32,6 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { at::cuda::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); -#if defined(USE_ROCM) - } else { - auto cpu_self = self.cpu(); - AT_DISPATCH_V2( - self.scalar_type(), "_local_scalar_dense_hip", AT_WRAP([&] { - r = Scalar(*cpu_self.const_data_ptr()); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); - } -#endif return r; } diff --git a/aten/src/ATen/native/cuda/Repeat.cu b/aten/src/ATen/native/cuda/Repeat.cu index aa351baaf9c0..1e2364ae5091 100644 --- a/aten/src/ATen/native/cuda/Repeat.cu +++ b/aten/src/ATen/native/cuda/Repeat.cu @@ -17,7 +17,13 @@ __global__ static void compute_cuda_kernel( index_t* result_ptr, int64_t size, int64_t result_size) { - CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]); + if (C10_UNLIKELY((result_size != cumsum_ptr[size - 1]))) { + printf("%s:%d:%s: block: [%d,%d,%d], thread: [%d,%d,%d] " + "Invalid input! In `repeat_interleave`, the `output_size` argument (%ld) must be the same as the sum of the elements in the `repeats` tensor (%ld).\n", + __FILE__, __LINE__, __func__,blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, result_size, cumsum_ptr[size - 1 ]); + CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]) + } + int64_t idx = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE; int warp_id = idx / C10_WARP_SIZE; diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index bdb169e26b14..940680eb3682 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -33,7 +33,12 @@ namespace at::native { namespace { constexpr int kCUDANumThreads = 256; +#ifdef USE_ROCM +// C10_WARP_SIZE is not constexpr for host code. +#define kWarpSize C10_WARP_SIZE +#else constexpr unsigned int kWarpSize = C10_WARP_SIZE; +#endif constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh) @@ -50,7 +55,7 @@ bool can_vectorize(const T * ptr, int alignment) { }; -template +template __global__ void RowwiseMomentsCUDAKernel( int64_t N, T_ACC eps, @@ -84,12 +89,17 @@ __global__ void RowwiseMomentsCUDAKernel( T_ACC m1; T_ACC m2; thrust::tie(m2, m1) = welford_op.project(val); - mean[i] = m1; - rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + if constexpr (!rms_norm){ + mean[i] = m1; + rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); + } else { + rstd[i] = c10::cuda::compat::rsqrt(m2 + m1 * m1 + eps); + } + } } -template +template __global__ void LayerNormForwardCUDAKernel( int64_t N, const T* X, @@ -103,11 +113,15 @@ __global__ void LayerNormForwardCUDAKernel( const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast(gamma[j]); - const T_ACC beta_v = - beta == nullptr ? T_ACC(0) : static_cast(beta[j]); - Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]) * gamma_v + - beta_v; + if constexpr (!rms_norm){ + const T_ACC beta_v = + beta == nullptr ? T_ACC(0) : static_cast(beta[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * gamma_v + + beta_v; + } else { + Y[index] = (static_cast(X[index])) * static_cast(rstd[i]) * gamma_v; + } } } @@ -119,40 +133,48 @@ struct WelfordDataLN{ C10_HOST_DEVICE WelfordDataLN(float mean, float sigma2, float count): mean(mean), sigma2(sigma2), count(count) {} }; -template __device__ +template __device__ WelfordDataLN cuWelfordOnlineSum( const U val, const WelfordDataLN& curr_sum) { - U delta = val - curr_sum.mean; - U new_count = curr_sum.count + 1.f; - U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster - return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + if constexpr (!rms_norm){ + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; + } else{ + return {0.f, curr_sum.sigma2 + val * val, 0}; + } } -__device__ +template __device__ WelfordDataLN cuWelfordCombine( const WelfordDataLN dataB, const WelfordDataLN dataA ) { - using U = decltype(dataB.count); - U delta = dataB.mean - dataA.mean; - U count = dataA.count + dataB.count; - U mean, sigma2; - if (count > decltype(dataB.count){0}) { - auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division - auto nA = dataA.count * coef; - auto nB = dataB.count * coef; - mean = nA*dataA.mean + nB*dataB.mean; - sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + if constexpr (!rms_norm){ + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA*dataA.mean + nB*dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; } else { - mean = U(0); - sigma2 = U(0); + return {0.f, dataB.sigma2 + dataA.sigma2, 0}; } - return {mean, sigma2, count}; } -template +template __device__ WelfordDataLN compute_stats( const T* __restrict__ X, const int N, @@ -171,14 +193,13 @@ __device__ WelfordDataLN compute_stats( vec_t data = X_vec[i]; #pragma unroll for (int ii=0; ii < vec_size; ii++){ - wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); + wd = cuWelfordOnlineSum(static_cast(data.val[ii]), wd); } } // intra-warp reduction for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) { - WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), - WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; - wd = cuWelfordCombine(wd, wdB); + WelfordDataLN wdB{WARP_SHFL_DOWN(wd.mean, offset), WARP_SHFL_DOWN(wd.sigma2, offset), WARP_SHFL_DOWN(wd.count, offset)}; + wd = cuWelfordCombine(wd, wdB); } // threadIdx.x == 0 has correct values for each warp // inter-warp reductions @@ -199,7 +220,7 @@ __device__ WelfordDataLN compute_stats( WelfordDataLN wdB{meansigmabuf[2*threadIdx.y], meansigmabuf[2*threadIdx.y+1], countbuf[threadIdx.y]}; - wd = cuWelfordCombine(wd, wdB); + wd = cuWelfordCombine(wd, wdB); } __syncthreads(); } @@ -216,7 +237,7 @@ __device__ WelfordDataLN compute_stats( } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int N, @@ -231,7 +252,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( //as one thread would have to write 3 consecutive floats auto i1 = blockIdx.x; const T * block_row = X + i1 * N; - WelfordDataLN wd = compute_stats(block_row, N, s_data); + WelfordDataLN wd = compute_stats(block_row, N, s_data); using vec_t = aligned_vector; const vec_t * X_vec = reinterpret_cast(block_row); @@ -254,34 +275,48 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( if (gamma_vec != nullptr && beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) - + static_cast(beta_vec[i].val[ii]); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (gamma_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + if constexpr (!rms_norm){ + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + } else { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * (rstd_val * static_cast(data.val[ii])); + } } } else if (beta_vec != nullptr) { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); + out.val[ii] = (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + static_cast(beta_vec[i].val[ii]); } } else { #pragma unroll for (int ii=0; ii < vec_size; ii++){ - out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + if constexpr (!rms_norm){ + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + } else { + out.val[ii] = rstd_val * static_cast(data.val[ii]); + } } } Y_vec[i] = out; } if (thrx == 0) { - mean[i1] = wd.mean; + if constexpr (!rms_norm){ + mean[i1] = wd.mean; + } rstd[i1] = rstd_val; } } -template , int> = 0> __device__ __inline__ void vectorized_layer_norm_kernel_impl( const int /*N*/, @@ -296,7 +331,7 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl( } //to avoid windows SFINAE errors -template +template __global__ void vectorized_layer_norm_kernel( const int N, T_ACC eps, @@ -306,11 +341,11 @@ __global__ void vectorized_layer_norm_kernel( T_ACC* mean, T_ACC* rstd, T* Y){ - vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); + vectorized_layer_norm_kernel_impl(N, eps, X, gamma, beta, mean, rstd, Y); } -template +template __device__ __inline__ void compute_gI( const T* __restrict__ dY, const T* __restrict__ X, @@ -321,7 +356,10 @@ __device__ __inline__ void compute_gI( const int N, T_ACC * buf){ const auto i1 = blockIdx.x; - const T_ACC mean_val = mean[i1]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[i1]; + } const T_ACC rstd_val = rstd[i1]; T_ACC stats_x1{0}, stats_x2{0}; constexpr int unroll = 4; @@ -337,26 +375,39 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l+k]) : T_ACC(1); const auto c_h = static_cast(X_i[l+k]); const auto c_loss = static_cast(dY_i[l+k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } for (; l < N; l ++) { const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } + } + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); } - - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, buf); stats_x2 = cuda_utils::BlockReduceSum(stats_x2, buf); if (threadIdx.x == 0) { - buf[0] = stats_x1; + if constexpr (!rms_norm){ + buf[0] = stats_x1; + } buf[1] = stats_x2; } __syncthreads(); - stats_x1 = buf[0]; + if constexpr (!rms_norm){ + stats_x1 = buf[0]; + } stats_x2 = buf[1]; T_ACC fH = N; T_ACC term1 = (T_ACC(1) / fH) * rstd_val; @@ -367,15 +418,20 @@ __device__ __inline__ void compute_gI( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } + f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void layer_norm_grad_input_kernel( const T* __restrict__ dY, const T* __restrict__ X, @@ -387,7 +443,7 @@ __global__ void layer_norm_grad_input_kernel( alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC * buf = reinterpret_cast(&s_data1); - compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); + compute_gI(dY, X, mean, rstd, gamma, dX, N, buf); } @@ -396,7 +452,7 @@ __global__ void layer_norm_grad_input_kernel( // faster measured at PT operator level, with cases seeing a 2X speedup (where N >> M). // There are no noticeable regressions on the rest of the sizes. -template +template __global__ void layer_norm_grad_input_kernel_vectorized( const T* __restrict__ dY, const T* __restrict__ X, @@ -409,7 +465,10 @@ __global__ void layer_norm_grad_input_kernel_vectorized( T_ACC* reduce_buf = reinterpret_cast(&shared_data); const auto bIdx = blockIdx.x; - const T_ACC mean_val = mean[bIdx]; + T_ACC mean_val = 0; + if constexpr (!rms_norm){ + mean_val = mean[bIdx]; + } const T_ACC rstd_val = rstd[bIdx]; const T* X_i = X + bIdx * N; const T* dY_i = dY + bIdx * N; @@ -441,8 +500,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = static_cast(gamma_vec_reg.val[k]); const auto c_h = static_cast(X_i_vec_reg.val[k]); const auto c_loss = static_cast(dY_i_vec_reg.val[k]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else { + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } } @@ -451,19 +514,29 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); const auto c_h = static_cast(X_i[l]); const auto c_loss = static_cast(dY_i[l]); - stats_x1 += c_loss * gamma_val; - stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + if constexpr (!rms_norm){ + stats_x1 += c_loss * gamma_val; + stats_x2 += c_loss * gamma_val * (c_h - mean_val) * rstd_val; + } else{ + stats_x2 += c_loss * gamma_val * (c_h) * rstd_val; + } } // Reduction in Shared Memory - stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + if constexpr (!rms_norm){ + stats_x1 = cuda_utils::BlockReduceSum(stats_x1, reduce_buf); + } stats_x2 = cuda_utils::BlockReduceSum(stats_x2, reduce_buf); if (threadIdx.x == 0) { - reduce_buf[0] = stats_x1; + if constexpr (!rms_norm){ + reduce_buf[0] = stats_x1; + } reduce_buf[1] = stats_x2; } __syncthreads(); - stats_x1 = reduce_buf[0]; + if constexpr (!rms_norm){ + stats_x1 = reduce_buf[0]; + } stats_x2 = reduce_buf[1]; T_ACC fH = N; @@ -485,8 +558,12 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto dy = static_cast(dY_i_vec_reg.val[k]); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i_vec_reg.val[k] = f_grad_input; } @@ -501,15 +578,19 @@ __global__ void layer_norm_grad_input_kernel_vectorized( const auto gamma_val = (gamma != nullptr) ? static_cast(gamma[l]) : T_ACC(1); T_ACC f_grad_input = fH * gamma_val * dy; - f_grad_input -= (x - mean_val) * rstd_val * stats_x2; - f_grad_input -= stats_x1; + if constexpr (!rms_norm){ + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + } else { + f_grad_input -= (x) * rstd_val * stats_x2; + } f_grad_input *= term1; dX_i[l] = f_grad_input; } } -template +template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, int64_t N, @@ -525,17 +606,25 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( T_ACC sum2 = 0; for (int64_t i = 0; i < M; ++i) { const int64_t index = i * N + j; - sum1 += dg == nullptr ? T_ACC(0) - : static_cast(dY[index]) * - (static_cast(X[index]) - static_cast(mean[i])) * - static_cast(rstd[i]); - sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + if constexpr (!rms_norm){ + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += db == nullptr ? T_ACC(0) : static_cast(dY[index]); + } else { + sum1 += dg == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index])) * static_cast(rstd[i]); + } } if (dg != nullptr) { dg[j] = sum1; } if (db != nullptr) { - db[j] = sum2; + if constexpr (!rms_norm){ + db[j] = sum2; + } } } } @@ -545,7 +634,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -569,7 +659,9 @@ blockReduceGammaBetaBackwardsHelper( int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; + if constexpr (!rms_norm){ + warp_mean = mean[mean_index + lane_id]; + } warp_rstd = rstd[mean_index + lane_id]; } // We do a WARP_SYNC() here because we use WARP_SHFL below to access @@ -596,10 +688,14 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; + if constexpr (!rms_norm){ + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } else{ + dg_sum += dY_regs[i] * (X_regs[i]) * rstd_reg; + } } } @@ -608,7 +704,8 @@ unsigned int block_dim_x, unsigned int block_dim_y, unsigned int rows_per_block_y, bool check_x, -bool check_y> +bool check_y, +bool rms_norm> __device__ __forceinline__ void @@ -629,10 +726,10 @@ blockReduceGammaBetaBackwardsWithChecks( M_start += rows_per_block_y * gridDim.y) { int64_t M_end = M_start + rows_per_block_y - 1; if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { - blockReduceGammaBetaBackwardsHelper + blockReduceGammaBetaBackwardsHelper (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -654,7 +751,8 @@ template __global__ void @@ -679,7 +777,7 @@ __launch_bounds__(block_dim_x * block_dim_y) // When N and M align perfectly with block_dim_x and block_dim_y, we // can skip boundary condition checks that waste instruction issue slots. blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { // In the general case we need to check boundary conditions in the M @@ -687,11 +785,11 @@ __launch_bounds__(block_dim_x * block_dim_y) // for the inner blocks. So try to avoid those checks when possible. if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } else { blockReduceGammaBetaBackwardsWithChecks - + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); } } @@ -706,7 +804,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[thread_y * N + thread_x] = dg_sum; } - if (db) { + if (db && !rms_norm) { db[thread_y * N + thread_x] = db_sum; } } @@ -752,7 +850,7 @@ __launch_bounds__(block_dim_x * block_dim_y) if (dg) { dg[out_index] = reg_dg; } - if (db) { + if (db && !rms_norm) { db[out_index] = reg_db; } } @@ -763,7 +861,8 @@ __launch_bounds__(block_dim_x * block_dim_y) template +bool partial_reduction, +bool rms_norm> void LaunchAndCheckGammaBetaBackwardKernel( bool aligned_grid, dim3 blocks, @@ -779,7 +878,7 @@ void LaunchAndCheckGammaBetaBackwardKernel( T* dgamma_data, T* dbeta_data) { if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -790,7 +889,7 @@ if (aligned_grid) { dgamma_data, dbeta_data); } else { - GammaBetaBackwardCUDAKernelTemplate + GammaBetaBackwardCUDAKernelTemplate <<>>( M, N, @@ -806,7 +905,7 @@ if (aligned_grid) { template +int rows_per_block_y, bool rms_norm> void ConfigureAndLaunchGammaBetaBackwardKernel( const T* dY_data, const T* X_data, @@ -829,16 +928,16 @@ void ConfigureAndLaunchGammaBetaBackwardKernel( if (blocks.y == 1 && threads.y == 1) { // Optimization: since there is just one thread doing all the summation, we don't need a reduction // across threads. So we set partial_reduction to true. - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } else { - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); } } -template +template void LaunchGammaBetaBackwardCUDAKernel( const T* dY_data, const T* X_data, @@ -876,19 +975,21 @@ void LaunchGammaBetaBackwardCUDAKernel( dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - if (dbeta->defined()) { + if (dbeta->defined() && !rms_norm) { auto options = dbeta->options(); dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } - LaunchAndCheckGammaBetaBackwardKernel( + LaunchAndCheckGammaBetaBackwardKernel( aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); if (dgamma_blocks.defined()) { *dgamma = dgamma_blocks.sum(0); } - if (dbeta_blocks.defined()) { - *dbeta = dbeta_blocks.sum(0); + if constexpr (!rms_norm){ + if (dbeta_blocks.defined()) { + *dbeta = dbeta_blocks.sum(0); + } } } else { // We are in the normal case where M is not that large. @@ -896,18 +997,18 @@ void LaunchGammaBetaBackwardCUDAKernel( // For small M it is faster to have a smaller tile, otherwise we could have idle threads. // For larger M we use a bigger tile size. if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } -template +template void launch_vectorized_layer_norm_kernel( int N, int64_t M, @@ -936,7 +1037,7 @@ void launch_vectorized_layer_norm_kernel( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(threads.y % 2 == 0 || threads.y == 1); int nshared = threads.y > 1 ? threads.y * 3/2 *sizeof(T_ACC) : 0; - vectorized_layer_norm_kernel<<>>(N, eps, X_data, + vectorized_layer_norm_kernel<<>>(N, eps, X_data, gamma_data, beta_data, mean_data, rstd_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -958,7 +1059,7 @@ void launch_vectorized_layer_norm_kernel( blocks.x = (remaining > blocks.x) ? blocks.x : remaining; - vectorized_layer_norm_kernel<<>>(N, eps, X_data2, + vectorized_layer_norm_kernel<<>>(N, eps, X_data2, gamma_data, beta_data, mean_data2, rstd_data2, Y_data2); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -968,7 +1069,7 @@ void launch_vectorized_layer_norm_kernel( } -template +template void LayerNormKernelImplInternal( const Tensor& X, const Tensor& gamma, @@ -987,7 +1088,7 @@ void LayerNormKernelImplInternal( const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; T* Y_data = Y->data_ptr(); - T_ACC* mean_data = mean->data_ptr(); + T_ACC* mean_data = !rms_norm ? mean->data_ptr() : nullptr; T_ACC* rstd_data = rstd->data_ptr(); // check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count), @@ -1002,14 +1103,14 @@ void LayerNormKernelImplInternal( if ((std::is_same_v || std::is_same_v || std::is_same_v) && N <= static_cast(1ULL << std::numeric_limits::digits) && N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) { - launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); + launch_vectorized_layer_norm_kernel(static_cast(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data); } else { cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); - RowwiseMomentsCUDAKernel + RowwiseMomentsCUDAKernel <<>>( N, eps, X_data, mean_data, rstd_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - LayerNormForwardCUDAKernel<<>>( + LayerNormForwardCUDAKernel<<>>( N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1037,7 +1138,29 @@ void LayerNormKernelImpl( }); } -template __device__ +void RmsNormKernelImpl( + const Tensor& X, + const Tensor& gamma, + int64_t M, + int64_t N, + double eps, + Tensor* Y, + Tensor* rstd) { +AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormKernelImpl", + [&]() { + using acc_t = acc_type; + // rms_norm = true + LayerNormKernelImplInternal( + // pass in at::Tensor() for gamma and nullptr for mean, it won't be accessed with rms_norm = True + X, gamma, at::Tensor(), M, N, static_cast(eps), Y, nullptr, rstd); + }); +} + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1055,7 +1178,10 @@ void cuLoadWriteStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1080,7 +1206,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -1098,7 +1224,11 @@ void cuLoadAddStridedInputs( { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - T_ACC curr_mean = mean[i1]; + + T_ACC curr_mean = 0; + if constexpr (!rms_norm){ + curr_mean = mean[i1]; + } T_ACC curr_rstd = rstd[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; @@ -1114,7 +1244,7 @@ void cuLoadAddStridedInputs( } } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const T* __restrict__ dout, const T* __restrict__ input, @@ -1140,9 +1270,9 @@ void cuComputePartGradGammaBeta( T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); } __syncthreads(); // inter-warp reductions @@ -1181,7 +1311,7 @@ void cuComputePartGradGammaBeta( } } -template __global__ +template __global__ void cuComputeGradGammaBeta( const T_ACC* part_grad_gamma, const T_ACC* part_grad_beta, @@ -1206,7 +1336,9 @@ void cuComputeGradGammaBeta( if (i2 < N) { for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { sum_gamma += part_grad_gamma_ptr[warp_offset*N]; - sum_beta += part_grad_beta_ptr[warp_offset*N]; + if constexpr (!rms_norm){ + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } } } @@ -1224,7 +1356,9 @@ void cuComputeGradGammaBeta( if (threadIdx.y < offset) { const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx+nbsize3]; + if constexpr (!rms_norm){ + sum_beta += buf[read_idx+nbsize3]; + } } __syncthreads(); } @@ -1235,12 +1369,14 @@ void cuComputeGradGammaBeta( grad_gamma[i2] = sum_gamma; } if (grad_beta) { - grad_beta[i2] = sum_beta; + if constexpr (!rms_norm){ + grad_beta[i2] = sum_beta; + } } } } -template __global__ +template __global__ void cuComputeGradInput( const T* __restrict__ dout, const T* __restrict__ input, @@ -1254,7 +1390,10 @@ void cuComputeGradInput( for (int i1=blockIdx.y; i1 < M; i1 += gridDim.y) { T_ACC sum_loss1 = T_ACC(0); T_ACC sum_loss2 = T_ACC(0); - T_ACC c_mean = mean[i1]; + T_ACC c_mean = 0; + if constexpr (!rms_norm){ + c_mean = mean[i1]; + } const T_ACC c_rstd = rstd[i1]; const T* k_input = input + i1*N; const T* k_dout = dout + i1*N; @@ -1267,21 +1406,31 @@ void cuComputeGradInput( const T_ACC gamma_idx = static_cast((idx((idx((idx((idx((idx 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + if constexpr (!rms_norm){ + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + } sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); } // inter-warp reductions @@ -1292,25 +1441,33 @@ void cuComputeGradInput( // upper half of warps write to shared if (threadIdx.y >= offset && threadIdx.y < 2*offset) { const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2*wrt_i] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*wrt_i] = sum_loss1; + } buf[2*wrt_i+1] = sum_loss2; } __syncthreads(); // lower half merges if (threadIdx.y < offset) { const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2*read_i]; + if constexpr (!rms_norm){ + sum_loss1 += buf[2*read_i]; + } sum_loss2 += buf[2*read_i+1]; } __syncthreads(); } if (threadIdx.y == 0) { - buf[2*threadIdx.x] = sum_loss1; + if constexpr (!rms_norm){ + buf[2*threadIdx.x] = sum_loss1; + } buf[2*threadIdx.x+1] = sum_loss2; } __syncthreads(); if (threadIdx.y !=0) { - sum_loss1 = buf[2*threadIdx.x]; + if constexpr (!rms_norm){ + sum_loss1 = buf[2*threadIdx.x]; + } sum_loss2 = buf[2*threadIdx.x+1]; } } @@ -1323,8 +1480,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss * gamma[l]; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1333,8 +1494,12 @@ void cuComputeGradInput( const T_ACC c_h = static_cast(k_input[l]); const T_ACC c_loss = static_cast(k_dout[l]); T_ACC f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + if constexpr (!rms_norm){ + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_rstd * sum_loss2; + } else { + f_grad_input -= (c_h) * c_rstd * sum_loss2; + } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } @@ -1344,7 +1509,7 @@ void cuComputeGradInput( } } -template +template void LayerNormBackwardKernelImplInternal( const Tensor& dY, const Tensor& X, @@ -1358,7 +1523,9 @@ void LayerNormBackwardKernelImplInternal( Tensor* dbeta) { using T_ACC = acc_type; TORCH_CHECK(dY.numel() == M * N); - TORCH_CHECK(mean.numel() == M); + if constexpr (!rms_norm){ + TORCH_CHECK(mean.numel() == M); + } TORCH_CHECK(rstd.numel() == M); TORCH_CHECK(M <= at::cuda::getCurrentDeviceProperties()->maxGridSize[0], "M should be less than maximum CUDA grid size, \ file a support request to support bigger batches"); @@ -1384,7 +1551,7 @@ void LayerNormBackwardKernelImplInternal( threads1.y > 1 ? threads1.y*threads1.x*sizeof(T_ACC) : 0; - cuComputeGradInput<<>>( + cuComputeGradInput<<>>( dY_data, X_data, M, N, @@ -1396,7 +1563,7 @@ void LayerNormBackwardKernelImplInternal( } else { const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1410,13 +1577,12 @@ void LayerNormBackwardKernelImplInternal( const unsigned int alignment = sizeof(T) * vec_size; bool bAlignedBuffers = can_vectorize(dY_data, alignment) && can_vectorize(X_data, alignment) && can_vectorize(gamma_data, alignment) && can_vectorize(dX_data, alignment); - if (bAlignedBuffers && bTargetDataTypes && bVectorSizeMultiple) { - layer_norm_grad_input_kernel_vectorized<<>>(dY_data, + layer_norm_grad_input_kernel_vectorized<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - layer_norm_grad_input_kernel<<>>(dY_data, + layer_norm_grad_input_kernel<<>>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, N); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -1432,7 +1598,7 @@ void LayerNormBackwardKernelImplInternal( if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; - GammaBetaBackwardSimpleCUDAKernel + GammaBetaBackwardSimpleCUDAKernel <<>>( M, N, @@ -1456,7 +1622,7 @@ void LayerNormBackwardKernelImplInternal( Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( + cuComputePartGradGammaBeta<<>>( dY_data, X_data, M,N, @@ -1470,7 +1636,7 @@ void LayerNormBackwardKernelImplInternal( const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - cuComputeGradGammaBeta<<>>( + cuComputeGradGammaBeta<<>>( part_grad_gamma.template data_ptr(), part_grad_beta.template data_ptr(), part_size, @@ -1480,7 +1646,7 @@ void LayerNormBackwardKernelImplInternal( C10_CUDA_KERNEL_LAUNCH_CHECK(); } #else - LaunchGammaBetaBackwardCUDAKernel( + LaunchGammaBetaBackwardCUDAKernel( dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif } @@ -1508,8 +1674,29 @@ void LayerNormBackwardKernelImpl( }); } +void RMSNormBackwardKernelImpl( + const Tensor& dY, + const Tensor& X, + const Tensor& rstd, + const Tensor& gamma, + int64_t M, + int64_t N, + Tensor* dX, + Tensor* dgamma) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "LayerNormBackwardKernelImpl", + [&]() { + LayerNormBackwardKernelImplInternal( + dY.contiguous(), X, rstd, rstd, gamma, M, N, dX, dgamma, dgamma); + }); +} + } // namespace + std::tuple layer_norm_cuda( const Tensor& input, IntArrayRef normalized_shape, @@ -1638,6 +1825,113 @@ std::tuple layer_norm_backward_cuda( return std::make_tuple(std::move(dX), std::move(dgamma), std::move(dbeta)); } +/* RMSNorm is implemented by reusing layer_norm's kernels */ +std::tuple _fused_rms_norm_cuda( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps){ + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + auto acc_type = at::toAccumulateType(input.scalar_type(), /*is_cuda=*/true); + double eps_val; + if (acc_type == at::ScalarType::Float) { + eps_val = eps.value_or(std::numeric_limits::epsilon()); + } else { + eps_val = eps.value_or(std::numeric_limits::epsilon()); + } + + Tensor Y = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); + + if (M > 0) { + RmsNormKernelImpl(*X, *gamma, M, N, eps_val, &Y, &rstd); + } + + const auto input_shape = input.sizes(); + const size_t axis = input.dim() - normalized_shape.size(); + + std::vector stat_shape; + for (const auto idx: c10::irange(axis)) { + stat_shape.push_back(input_shape[idx]); + } + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { + stat_shape.push_back(1); + } + + rstd = rstd.view(stat_shape); + + return std::make_tuple(std::move(Y), std::move(rstd)); +} + + +std::tuple _fused_rms_norm_backward_cuda( + const Tensor& dY, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt /* optional */, + std::array grad_input_mask) { + + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + auto M_N = _check_layer_norm_inputs(input, normalized_shape, weight, weight); + auto M = M_N.first; + auto N = M_N.second; + auto X = input.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + + Tensor dX; + Tensor dgamma; + if (grad_input_mask[0]) { + dX = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + if (grad_input_mask[1]) { + dgamma = M > 0 ? at::native::empty_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT) + : at::native::zeros_like( + *gamma, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + } + + if (M > 0 && N > 0) { + RMSNormBackwardKernelImpl( + dY, *X, rstd, *gamma, M, N, &dX, &dgamma); + } + return std::make_tuple(std::move(dX), std::move(dgamma)); +} + REGISTER_DISPATCH(LayerNormKernel, &LayerNormKernelImpl) REGISTER_DISPATCH(LayerNormBackwardKernel, &LayerNormBackwardKernelImpl) diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index 4f9e612e8e75..48119a6a3b4c 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -84,37 +84,6 @@ void run_cudnn_SDP_bprop( false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); } -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset) { - TORCH_CHECK( - false, "PyTorch was not compiled with cuDNN Flash Attention enabled!"); -} - } // namespace native } // namespace at @@ -142,6 +111,40 @@ namespace native { #include namespace fe = cudnn_frontend; +using graph_and_tensors = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias + std::shared_ptr, // Attn_scale, + // TODO(eqy): additional options + // std::shared_ptr, // SEQ_LEN_Q, + // std::shared_ptr, // SEQ_LEN_KV, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + // std::shared_ptr, // Dropout_mask, + // std::shared_ptr, // Dropout_scale + std::shared_ptr, // O + std::shared_ptr // Stats + >; + +using graph_and_tensors_backward = std::tuple< + std::shared_ptr, + std::shared_ptr, // Q, + std::shared_ptr, // K, + std::shared_ptr, // V, + std::optional>, // Bias, + std::shared_ptr, // Attn_scale, + std::shared_ptr, // Seed, + std::shared_ptr, // Offset, + std::shared_ptr, // O, + std::shared_ptr, // dO, + std::shared_ptr, // stats, + std::shared_ptr, // dQ, + std::shared_ptr, // dK,, + std::shared_ptr // dV, + >; #define MAX_MHA_DIM 4 @@ -295,45 +298,11 @@ struct MHAGraphCache { // @eqy: use thread local caches as cuDNN Execution Plans are not guaranteed to // be thread safe across all engines see Limitations in // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/release-notes.html -// We also leak the caches to workaround potential teardown race issues. - -auto& getMHAGraphCache_() { - thread_local auto& instance = - *new MHAGraphCache, MHACacheKeyWrapper>; - return instance; -} - -auto& getMHAGraphBackwardCache_() { - thread_local auto& instance = - *new MHAGraphCache, MHACacheKeyWrapper>; - return instance; -} +thread_local MHAGraphCache mhagraphcache; +thread_local MHAGraphCache + mhagraphbackwardcache; namespace { - -enum UIDS { - Q, - K, - V, - O, - BIAS, - SCALE, - SEED, - OFFSET, - LSE, - DO, - DQ, - DK, - DV, - SEQ_LEN_Q, - SEQ_LEN_KV, - RAG_Q_OFF, - RAG_K_OFF, - RAG_V_OFF, - RAG_O_OFF, - RAG_LSE_OFF -}; - // analogous to the same function in Descriptors.h for cuDNN Convolutions... auto fixSizeOneDimStrideSDPA( const IntArrayRef sizes, @@ -351,10 +320,9 @@ auto fixSizeOneDimStrideSDPA( } return strides; } - } // namespace -auto build_graph( +auto build_graph_and_tensors( int64_t b, int64_t h, int64_t s_q, @@ -387,55 +355,46 @@ auto build_graph( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto scaled_dot_product_flash_attention_options = - fe::graph::SDPA_attributes() - .set_name("CUDNN_SDPA") - .set_is_inference(return_softmaxstats == false) - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale); - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutseed.dtype() == kInt + dropoutoffset.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - scaled_dot_product_flash_attention_options.set_dropout( - dropout_probability, seed, offset); - } - auto Q_ = mha_graph->tensor( + auto scaled_dot_product_flash_attention_options = + fe::graph::SDPA_attributes() + .set_name("CUDNN_SDPA") + .set_is_inference(return_softmaxstats == false) + .set_causal_mask(is_causal) + .set_attn_scale(attn_scale) + .set_dropout(dropout_probability, seed, offset); + auto Q = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(Q) .set_name("Q") .set_dim(q.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(q.sizes(), q.strides().vec()))); - auto K_ = mha_graph->tensor( + auto K = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(K) .set_name("K") .set_dim(k.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec()))); - auto V_ = mha_graph->tensor( + auto V = mha_graph->tensor( fe::graph::Tensor_attributes() - .set_uid(V) .set_name("V") .set_dim(v.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec()))); @@ -443,20 +402,17 @@ auto build_graph( if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto [O_, Stats] = - mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); - O_->set_uid(O); - O_->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); + O->set_output(true).set_dim(o.sizes().vec()).set_stride(o.strides().vec()); if (Stats) { - Stats->set_uid(LSE); Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } @@ -467,10 +423,20 @@ auto build_graph( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats)); } -auto build_graph_nestedtensor( +auto build_graph_and_tensors_nestedtensor( int64_t b, int64_t h_q, int64_t h_k, @@ -507,22 +473,28 @@ auto build_graph_nestedtensor( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto SEQ_LEN_Q_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_Q) - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV_ = + auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto SEQ_LEN_KV = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_KV) .set_name("Seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -534,66 +506,41 @@ auto build_graph_nestedtensor( .set_is_inference(return_softmaxstats == false) .set_causal_mask(is_causal) .set_attn_scale(attn_scale) - .set_seq_len_q(SEQ_LEN_Q_) - .set_seq_len_kv(SEQ_LEN_KV_) + .set_dropout(dropout_probability, seed, offset) + .set_seq_len_q(SEQ_LEN_Q) + .set_seq_len_kv(SEQ_LEN_KV) .set_padding_mask(true); - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - scaled_dot_product_flash_attention_options.set_dropout( - dropout_probability, seed, offset); - } // We hardcode BSHD to cuDNN even though the underlying layout is THD auto q_strides = q.strides(); auto k_strides = k.strides(); auto v_strides = v.strides(); - // NB: cuDNN API shape is transposed constexpr int strideidx0 = 1; constexpr int strideidx1 = 0; constexpr int strideidx2 = 2; - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h_q, s_q, d_qk}) + .set_stride( + {INT_MAX, + q_strides[strideidx0], + q_strides[strideidx1], + q_strides[strideidx2]})); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, h_k, s_kv, d_qk}) + .set_stride( + {INT_MAX, + k_strides[strideidx0], + k_strides[strideidx1], + k_strides[strideidx2]})); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, h_v, s_kv, d_v}) + .set_stride( + {INT_MAX, + v_strides[strideidx0], + v_strides[strideidx1], + v_strides[strideidx2]})); std::optional> bias; if (attn_bias.has_value()) { TORCH_CHECK( @@ -601,48 +548,44 @@ auto build_graph_nestedtensor( "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); scaled_dot_product_flash_attention_options.set_bias(bias.value()); } - auto RAG_Q_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_Q_OFF) - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_K_OFF) - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_V_OFF) - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_O_OFF) - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - Q_->set_ragged_offset(RAG_Q_OFF_); - K_->set_ragged_offset(RAG_K_OFF_); - V_->set_ragged_offset(RAG_V_OFF_); - auto [O_, Stats] = - mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); + auto RAG_Q_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_K_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_V_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto RAG_O_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("cum_seq_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + // auto RAG_STATS_OFF = mha_graph->tensor(fe::graph::Tensor_attributes() + // .set_name("cum_seq_stats") + // .set_dim({b + 1, 1, 1, 1}) + // .set_stride({1, 1, 1, 1}) + // .set_data_type(fe::DataType_t::INT32)); + auto RAG_STATS_OFF = nullptr; + Q->set_ragged_offset(RAG_Q_OFF); + K->set_ragged_offset(RAG_K_OFF); + V->set_ragged_offset(RAG_V_OFF); + auto [O, Stats] = + mha_graph->sdpa(Q, K, V, scaled_dot_product_flash_attention_options); auto o_strides = o.strides(); - O_->set_output(true) - .set_uid(O) + O->set_output(true) .set_dim({b, h_q, s_q, d_v}) .set_stride( {INT_MAX, @@ -650,20 +593,16 @@ auto build_graph_nestedtensor( o_strides[strideidx1], o_strides[strideidx2]}); - O_->set_ragged_offset(RAG_O_OFF_); + O->set_ragged_offset(RAG_O_OFF); if (Stats) { - auto RAG_STATS_OFF = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_LSE_OFF) - .set_name("cum_seq_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + TORCH_CHECK( + false, + "cuDNN SDPA Nested Tensor does not yet handle backwards/logsumexp computation"); + // TODO(eqy): fix when stats (backward) support is added Stats->set_output(true) - .set_uid(LSE) .set_data_type(fe::DataType_t::FLOAT) .set_dim({b, h_q, s_q, 1}) - .set_stride({h_q * s_q, 1, h_q, 1}); + .set_stride({h_q * s_q * d_v, d_v, s_q * d_v, 1}); Stats->set_ragged_offset(RAG_STATS_OFF); } AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); @@ -672,10 +611,27 @@ auto build_graph_nestedtensor( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(seed), + std::move(offset), + std::move(O), + std::move(Stats), + std::move(RAG_Q_OFF), + std::move(RAG_K_OFF), + std::move(RAG_V_OFF), + std::move(RAG_O_OFF), + std::move(RAG_STATS_OFF), + std::move(SEQ_LEN_Q), + std::move(SEQ_LEN_KV)); } -auto build_graph_backward( +auto build_graph_and_tensors_backward( int64_t b, int64_t h, int64_t s_q, @@ -711,7 +667,6 @@ auto build_graph_backward( .set_compute_data_type(fe::DataType_t::FLOAT); auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) .set_name("Attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -721,327 +676,87 @@ auto build_graph_backward( .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) .set_attn_scale(attn_scale); - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim(q.sizes().vec()) - .set_stride(q.strides().vec())); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim(k.sizes().vec()) - .set_stride(k.strides().vec())); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim(v.sizes().vec()) - .set_stride(v.strides().vec())); + auto Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim(q.sizes().vec()) + .set_stride(q.strides().vec())); + auto K = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim(k.sizes().vec()) + .set_stride(k.strides().vec())); + auto V = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim(v.sizes().vec()) + .set_stride(v.strides().vec())); std::optional> bias; if (attn_bias.has_value()) { bias = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) .set_stride(attn_bias.value().strides().vec())); sdpa_backward_options.set_bias(bias.value()); } - if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") + auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type( + dropoutseed.dtype() == kInt + ? fe::DataType_t::INT32 + : fe::DataType_t::INT64)); + + auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type( - dropoutseed.dtype() == kInt + dropoutoffset.dtype() == kInt ? fe::DataType_t::INT32 : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - sdpa_backward_options.set_dropout(dropout_probability, seed, offset); - } - auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(O) - .set_name("O") - .set_dim(o.sizes().vec()) - .set_stride(o.strides().vec())); - auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(LSE) + auto O = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim(o.sizes().vec()) + .set_stride(o.strides().vec())); + auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Stats") .set_dim(softmaxstats.sizes().vec()) .set_stride(softmaxstats.strides().vec()) .set_data_type(fe::DataType_t::FLOAT)); - auto Do = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(DO) + auto DO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("DO") .set_dim(dO.sizes().vec()) .set_stride(dO.strides().vec())); - auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( - Q_, K_, V_, O_, Do, Stats, sdpa_backward_options); - Dq->set_uid(DQ); - Dq->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); - Dk->set_uid(DK); - Dk->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); - Dv->set_uid(DV); - Dv->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); - AT_CUDNN_FRONTEND_CHECK( - mha_graph->create_execution_plans({fe::HeurMode_t::A})); - AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; -} - -auto build_graph_backward_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset, - cudnnHandle_t& handle) { - auto dtype = fe::DataType_t::HALF; - if (q.scalar_type() == kBFloat16) { - dtype = fe::DataType_t::BFLOAT16; - } - auto mha_graph = std::make_shared(); - // We're baking in float accumulation and scale types - // in theory the graph may support other types, but they - // have not been tested - mha_graph->set_io_data_type(dtype) - .set_intermediate_data_type(fe::DataType_t::FLOAT) - .set_compute_data_type(fe::DataType_t::FLOAT); - auto attn_scale = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SCALE) - .set_name("Attn_scale") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); - - auto SEQ_LEN_Q_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_Q) - .set_name("Seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto SEQ_LEN_KV_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEQ_LEN_KV) - .set_name("Seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() - .set_name("CUDNN_SDPA_NESTEDTENSOR_BACKWARD") - .set_causal_mask(is_causal) - .set_attn_scale(attn_scale) - .set_seq_len_q(SEQ_LEN_Q_) - .set_seq_len_kv(SEQ_LEN_KV_) - .set_padding_mask(true); if (dropout_probability != 0.0f) { - auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(SEED) - .set_name("Seed") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(OFFSET) - .set_name("Offset") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); - sdpa_backward_options.set_dropout(dropout_probability, seed, offset); + sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset); } - auto q_strides = q.strides(); - auto k_strides = k.strides(); - auto v_strides = v.strides(); - // NB: cuDNN API shape is transposed - constexpr int strideidx0 = 1; - constexpr int strideidx1 = 0; - constexpr int strideidx2 = 2; - auto Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(Q) - .set_name("Q") - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]})); - auto K_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(K) - .set_name("K") - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]})); - auto V_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(V) - .set_name("V") - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]})); - auto o_strides = o.strides(); - auto O_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(O) - .set_name("O") - .set_dim({b, h_q, s_q, d_v}) - .set_stride( - {INT_MAX, - o_strides[strideidx0], - o_strides[strideidx1], - o_strides[strideidx2]})); - - std::optional> bias; - if (attn_bias.has_value()) { - TORCH_CHECK( - false, - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); - bias = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(BIAS) - .set_name("bias") - .set_dim(attn_bias.value().sizes().vec()) - .set_stride(attn_bias.value().strides().vec())); - sdpa_backward_options.set_bias(bias.value()); - } - auto RAG_Q_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_Q_OFF) - .set_name("cum_seq_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_K_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_K_OFF) - .set_name("cum_seq_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_V_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_V_OFF) - .set_name("cum_seq_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_O_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_O_OFF) - .set_name("cum_seq_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - auto RAG_STATS_OFF_ = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(RAG_LSE_OFF) - .set_name("cum_seq_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - O_->set_ragged_offset(RAG_O_OFF_); - Q_->set_ragged_offset(RAG_Q_OFF_); - K_->set_ragged_offset(RAG_K_OFF_); - V_->set_ragged_offset(RAG_V_OFF_); - auto STATS = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_uid(LSE) - .set_name("stats") - .set_dim({b, h_q, s_q, 1}) - .set_stride({s_q * h_q, 1, h_q, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - STATS->set_ragged_offset(RAG_STATS_OFF_); - auto do_strides = dO.strides(); - auto DO_ = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_ragged_offset(RAG_O_OFF_) - .set_uid(DO) - .set_name("DO") - .set_dim({b, h_q, s_q, d_v}) - .set_stride( - {INT_MAX, - do_strides[strideidx0], - do_strides[strideidx1], - do_strides[strideidx2]})); - auto [Dq, Dk, Dv] = mha_graph->sdpa_backward( - Q_, K_, V_, O_, DO_, STATS, sdpa_backward_options); - Dq->set_output(true) - .set_uid(DQ) - .set_ragged_offset(RAG_Q_OFF_) - .set_dim({b, h_q, s_q, d_qk}) - .set_stride( - {INT_MAX, - q_strides[strideidx0], - q_strides[strideidx1], - q_strides[strideidx2]}); - Dk->set_output(true) - .set_uid(DK) - .set_ragged_offset(RAG_K_OFF_) - .set_dim({b, h_k, s_kv, d_qk}) - .set_stride( - {INT_MAX, - k_strides[strideidx0], - k_strides[strideidx1], - k_strides[strideidx2]}); - Dv->set_output(true) - .set_uid(DV) - .set_ragged_offset(RAG_V_OFF_) - .set_dim({b, h_v, s_kv, d_v}) - .set_stride( - {INT_MAX, - v_strides[strideidx0], - v_strides[strideidx1], - v_strides[strideidx2]}); - + auto [DQ, DK, DV] = + mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options); + DQ->set_output(true).set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); + DK->set_output(true).set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); + DV->set_output(true).set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); AT_CUDNN_FRONTEND_CHECK( mha_graph->create_execution_plans({fe::HeurMode_t::A})); AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); - return mha_graph; + return std::make_tuple( + std::move(mha_graph), + std::move(Q), + std::move(K), + std::move(V), + std::move(bias), + std::move(attn_scale), + std::move(Seed), + std::move(Offset), + std::move(O), + std::move(DO), + std::move(STATS), + std::move(DQ), + std::move(DK), + std::move(DV)); } void run_cudnn_SDP_fprop( @@ -1102,12 +817,12 @@ void run_cudnn_SDP_fprop( dropout_probability, is_causal, return_softmaxstats); - auto graph_ptr = getMHAGraphCache_().find(key); - std::shared_ptr mha_graph; - if (graph_ptr) { - mha_graph = *graph_ptr; + auto graph_and_tensors_ptr = mhagraphcache.find(key); + graph_and_tensors graph_and_tensors_values; + if (graph_and_tensors_ptr) { + graph_and_tensors_values = *graph_and_tensors_ptr; } else { - mha_graph = build_graph( + graph_and_tensors_values = build_graph_and_tensors( b, h, s_q, @@ -1128,28 +843,29 @@ void run_cudnn_SDP_fprop( _dropoutoffset, handle); } - std::unordered_map variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {SCALE, &scaling_factor}, - {O, o.data_ptr()}}; + auto [mha_graph, Q, K, V, bias, attn_scale, seed, offset, O, Stats] = + graph_and_tensors_values; + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, _dropoutseed.data_ptr()}, + {offset, _dropoutoffset.data_ptr()}, + {O, o.data_ptr()}}; if (return_softmaxstats) { - variant_pack[LSE] = softmaxstats.data_ptr(); + variant_pack[Stats] = softmaxstats.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[BIAS] = attn_bias.value().data_ptr(); - } - if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + variant_pack[bias.value()] = attn_bias.value().data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - getMHAGraphCache_().update(key, mha_graph); + mhagraphcache.update(key, graph_and_tensors_values); } void run_cudnn_SDP_fprop_nestedtensor( @@ -1188,55 +904,72 @@ void run_cudnn_SDP_fprop_nestedtensor( if (return_softmaxstats && !softmaxstats.defined()) { softmaxstats = at::empty({q.size(0), h_q, 1}, q.options().dtype(kFloat)); } - auto mha_graph = build_graph_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - return_softmaxstats, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - softmaxstats, - o, - dropoutseed, - dropoutoffset, - handle); + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + seed, + offset, + O, + Stats, + RAG_Q_OFF, + RAG_K_OFF, + RAG_V_OFF, + RAG_O_OFF, + RAG_STATS_OFF, + SEQ_LEN_Q, + SEQ_LEN_KV] = + build_graph_and_tensors_nestedtensor( + b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + scaling_factor, + return_softmaxstats, + is_causal, + dropout_probability, + cum_seqlen_q, + cum_seqlen_kv, + q, + k, + v, + attn_bias, + softmaxstats, + o, + dropoutseed, + dropoutoffset, + handle); auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); + auto rag_k_off = cum_seqlen_kv.mul(h_k * d_qk); auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); auto rag_stats_off = cum_seqlen_q.mul(h_q); - std::unordered_map variant_pack = { - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {SCALE, &scaling_factor}, - {O, o.data_ptr()}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; + std::unordered_map, void*> + variant_pack = { + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {attn_scale, &scaling_factor}, + {seed, dropoutseed.data_ptr()}, + {offset, dropoutoffset.data_ptr()}, + {O, o.data_ptr()}, + {RAG_Q_OFF, rag_q_off.data_ptr()}, + {RAG_O_OFF, rag_q_off.data_ptr()}, + {RAG_K_OFF, rag_k_off.data_ptr()}, + {RAG_V_OFF, rag_v_off.data_ptr()}, + {SEQ_LEN_Q, seqlen_q.data_ptr()}, + {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; if (return_softmaxstats) { - variant_pack[LSE] = softmaxstats.data_ptr(); - variant_pack[RAG_LSE_OFF] = rag_stats_off.data_ptr(); - } - if (dropout_probability != 0.0f) { - variant_pack[SEED] = dropoutseed.data_ptr(); - variant_pack[OFFSET] = dropoutoffset.data_ptr(); + variant_pack[Stats] = softmaxstats.data_ptr(); + variant_pack[RAG_STATS_OFF] = cum_seqlen_q.data_ptr(); } if (attn_bias.has_value()) { TORCH_CHECK("bias not supported with nestedtensor"); @@ -1320,12 +1053,12 @@ void run_cudnn_SDP_bprop( dropout_probability, is_causal, true); - auto graph_backward_ptr = getMHAGraphBackwardCache_().find(key); - std::shared_ptr mha_graph; - if (graph_backward_ptr) { - mha_graph = *graph_backward_ptr; + auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key); + graph_and_tensors_backward graph_and_tensors_backward_values; + if (graph_and_tensors_backward_ptr) { + graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr; } else { - mha_graph = build_graph_backward( + graph_and_tensors_backward_values = build_graph_and_tensors_backward( b, h, s_q, @@ -1349,25 +1082,41 @@ void run_cudnn_SDP_bprop( _dropoutoffset, handle); } - std::unordered_map variant_pack = { - // inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {DO, dO_.data_ptr()}, - {LSE, softmaxstats.data_ptr()}, - // outputs - {DQ, dQ.data_ptr()}, - {DK, dK.data_ptr()}, - {DV, dV.data_ptr()}, - {SCALE, &scaling_factor}}; + auto + [mha_graph, + Q, + K, + V, + bias, + attn_scale, + Seed, + Offset, + O, + Do, + Stats, + Dq, + Dk, + Dv] = graph_and_tensors_backward_values; + std::unordered_map, void*> + variant_pack = {// inputs + {Q, q.data_ptr()}, + {K, k.data_ptr()}, + {V, v.data_ptr()}, + {O, o.data_ptr()}, + {Do, dO_.data_ptr()}, + {Stats, softmaxstats.data_ptr()}, + // outputs + {Dq, dQ.data_ptr()}, + {Dk, dK.data_ptr()}, + {Dv, dV.data_ptr()}, + // pass by value + {attn_scale, &scaling_factor}}; if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); + variant_pack[Seed] = _dropoutseed.data_ptr(); + variant_pack[Offset] = _dropoutoffset.data_ptr(); } if (attn_bias.has_value()) { - variant_pack[BIAS] = attn_bias.value().data_ptr(); + variant_pack[bias.value()] = attn_bias.value().data_ptr(); } auto workspace_size = mha_graph->get_workspace_size(); auto workspace_ptr = @@ -1375,127 +1124,7 @@ void run_cudnn_SDP_bprop( TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); - getMHAGraphBackwardCache_().update(key, mha_graph); -} - -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset) { - // do nothing if we got 0-element tensors - if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || - !softmaxstats.numel()) { - return; - } - - Tensor dO_ = dO; - const auto innermost_dO_stride = dO.strides()[dO.strides().size() - 1]; - if (innermost_dO_stride != 1) { - permute_to_matching_layout(o, dO_); - } - - auto seqlen_q = at::diff(cum_seqlen_q, 1, 0); - auto seqlen_kv = at::diff(cum_seqlen_kv, 1, 0); - auto rag_q_off = cum_seqlen_q.mul(h_q * d_qk); - auto rag_k_off = cum_seqlen_kv.mul(h_k * d_v); - auto rag_v_off = cum_seqlen_kv.mul(h_v * d_v); - auto rag_stats_off = cum_seqlen_q.mul(h_q); - - auto dprops = at::cuda::getCurrentDeviceProperties(); - auto _dropoutseed = dropoutseed; - auto _dropoutoffset = dropoutoffset; - // cuDNN dropout bug requires these to be in int64 - if (dprops->major == 10 && dprops->minor == 0) { - _dropoutseed = dropoutseed.to(kLong); - _dropoutoffset = dropoutoffset.to(kLong); - } - - cudnnHandle_t handle = getCudnnHandle(); - - auto mha_graph = build_graph_backward_nestedtensor( - b, - h_q, - h_k, - h_v, - s_q, - s_kv, - d_qk, - d_v, - scaling_factor, - is_causal, - dropout_probability, - cum_seqlen_q, - cum_seqlen_kv, - q, - k, - v, - attn_bias, - o, - dO_, - softmaxstats, - dQ, - dK, - dV, - dropoutseed, - dropoutoffset, - handle); - - std::unordered_map variant_pack = { - // inputs - {Q, q.data_ptr()}, - {K, k.data_ptr()}, - {V, v.data_ptr()}, - {O, o.data_ptr()}, - {DO, dO_.data_ptr()}, - {LSE, softmaxstats.data_ptr()}, - // outputs - {DQ, dQ.data_ptr()}, - {DK, dK.data_ptr()}, - {DV, dV.data_ptr()}, - {SCALE, &scaling_factor}, - {RAG_Q_OFF, rag_q_off.data_ptr()}, - {RAG_O_OFF, rag_q_off.data_ptr()}, - {RAG_K_OFF, rag_k_off.data_ptr()}, - {RAG_V_OFF, rag_v_off.data_ptr()}, - {RAG_LSE_OFF, rag_stats_off.data_ptr()}, - {SEQ_LEN_Q, seqlen_q.data_ptr()}, - {SEQ_LEN_KV, seqlen_kv.data_ptr()}}; - if (dropout_probability != 0.0f) { - variant_pack[SEED] = _dropoutseed.data_ptr(); - variant_pack[OFFSET] = _dropoutoffset.data_ptr(); - } - TORCH_CHECK( - !attn_bias.has_value(), - "attn_bias not yet supportd with cuDNN Attention and NestedTensor"); - - auto workspace_size = mha_graph->get_workspace_size(); - auto workspace_ptr = - c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); - TORCH_CHECK(!workspace_size || workspace_ptr.get()); - TORCH_CHECK( - mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good()); + mhagraphbackwardcache.update(key, graph_and_tensors_backward_values); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 620abc1aa0a8..045e8cf6dee9 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -70,31 +70,4 @@ void run_cudnn_SDP_bprop( const Tensor& dropoutseed, const Tensor& dropoutoffset); -void run_cudnn_SDP_bprop_nestedtensor( - int64_t b, - int64_t h_q, - int64_t h_k, - int64_t h_v, - int64_t s_q, - int64_t s_kv, - int64_t d_qk, - int64_t d_v, - float scaling_factor, - bool is_causal, - float dropout_probability, - const Tensor& cum_seqlen_q, - const Tensor& cum_seqlen_kv, - const Tensor& q, - const Tensor& k, - const Tensor& v, - const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset); - } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index da6bb5fec39e..950aa99a9aab 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -261,30 +261,11 @@ std::tuple math_native_layer_norm( return outputs; } -Tensor rms_norm_symint( +std::tuple rms_norm_composite( const Tensor& input, - c10::SymIntArrayRef normalized_shape, + IntArrayRef normalized_shape, const std::optional& weight_opt /* optional */, std::optional eps) { - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); - const Tensor& weight = *weight_maybe_owned; - _check_rms_norm_inputs_symint(input, normalized_shape, weight); - -#ifdef USE_MPS - if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { - const Tensor weight = weight_opt.value(); - const bool any_nested = input.is_nested() || weight.is_nested(); - const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); - const bool is_input_fp = isFloatingType(input.scalar_type()); - const bool is_weight_fp = isFloatingType(weight.scalar_type()); - - if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) { - auto eps_val = eps.value_or(std::numeric_limits::epsilon()); - return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val); - } - } -#endif std::vector dims_to_reduce; for (const auto i : c10::irange(normalized_shape.size())) { @@ -321,10 +302,67 @@ Tensor rms_norm_symint( upcasted_result = upcasted_result.mul(weight_opt.value()); } - return upcasted_result; + // if nested do not make contiguous + if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::make_tuple(upcasted_result, rqrst_input); + } + + return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous()); }); + return std::make_tuple( + std::get<0>(result).type_as(input), // Cast normalized result to original input type + std::get<1>(result) // rsqrt_val + ); +} + + +Tensor rms_norm_symint( + const Tensor& input, + c10::SymIntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + const std::optional eps) { + + c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + _check_rms_norm_inputs_symint(input, normalized_shape, weight); + + // composite fallback for channels last + if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } - return result.type_as(input); + // composite fallback for complex datatypes + if(input.is_complex()){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + if (weight_opt.has_value() && weight_opt.value().defined() && weight_opt.value().dtype() != input.dtype()) { + TORCH_WARN_ONCE( + "Mismatch dtype between input and module: input dtype = ", input.dtype(), + ", module dtype = ", weight_opt.value().dtype(), ", Can not dispatch to fused implementation" + ); + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + + #ifdef USE_MPS + if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) { + const Tensor weight = weight_opt.value(); + const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad(); + + if (!(GradMode::is_enabled() && any_inputs_require_grad)) { + return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + } + + if (input.device().type() == DeviceType::MPS){ + return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); + } + #endif + return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast(normalized_shape.data()), normalized_shape.size()), weight_opt, eps)); } + } // namespace at::native diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index 0181f35fd6ed..0debe942dd0a 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -106,6 +106,12 @@ void layer_norm_cpu_out( int64_t M, int64_t N); +std::tuple rms_norm_composite( + const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt /* optional */, + std::optional eps); + Tensor rms_norm_symint( const Tensor& input, c10::SymIntArrayRef normalized_shape, diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 1e2993e79f4d..8222304e6d07 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -160,6 +160,10 @@ static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ mkldnn_bf16_device_check(); } +static bool mkldnn_conv_enabled_fpmath_mode_tf32(){ + return at::globalContext().float32Precision("mkldnn", "conv") == "tf32" && + cpuinfo_has_x86_amx_fp16(); +} static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { auto memory_format = at::MemoryFormat::Contiguous; @@ -271,6 +275,10 @@ static Tensor _mkldnn_convolution( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } _mkldnn_convolution_out( input_t, weight_t, @@ -455,6 +463,9 @@ Tensor mkldnn_convolution_pointwise_binary( if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); @@ -597,6 +608,10 @@ Tensor& mkldnn_convolution_pointwise_binary_( input_t.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } _mkldnn_convolution_out( input_t, weight_t, @@ -718,6 +733,9 @@ Tensor _mkldnn_convolution_transpose( if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true); @@ -808,6 +826,10 @@ Tensor mkldnn_convolution_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } ideep::convolution_backward_data::compute_v2( grad_y, w, @@ -828,6 +850,11 @@ Tensor mkldnn_convolution_backward_input( TORCH_WARN_ONCE( "Unexpected ideep version to support fpmath_mode_bf16, please update ideep version to align with pytorch main branch"); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + TORCH_WARN_ONCE( + "Unexpected ideep version to support fpmath_mode_tf32, please update ideep version to align with pytorch main branch"); + } #endif if (grad_output.is_mkldnn()) { @@ -858,6 +885,10 @@ std::tuple mkldnn_convolution_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias_defined) { ideep::convolution_backward_weights::compute_v2( x, @@ -1011,6 +1042,10 @@ Tensor mkldnn_convolution_transpose_backward_input( weight.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + weight.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } ideep::convolution_transpose_backward_data::compute_v3( grad_y, w, @@ -1053,6 +1088,10 @@ std::tuple mkldnn_convolution_transpose_backward_weights( input.scalar_type() == at::kFloat) { op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (mkldnn_conv_enabled_fpmath_mode_tf32() && + input.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (bias_defined) { ideep::convolution_transpose_backward_weights::compute_v3( x, diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 8dbb29bb3e01..8f0b91b3e3f7 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -73,6 +73,11 @@ static bool use_mkldnn_bf32_linear() { mkldnn_bf16_device_check(); } +static bool use_mkldnn_tf32_linear() { + return at::globalContext().float32Precision("mkldnn", "matmul") == "tf32" && + cpuinfo_has_x86_amx_fp16(); +} + Tensor mkldnn_linear( const Tensor& self, const Tensor& weight_t, const std::optional& bias_opt) { @@ -259,6 +264,9 @@ Tensor mkldnn_linear_pointwise( if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){ op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute( mkldnn_input, @@ -352,6 +360,10 @@ Tensor mkldnn_linear_pointwise_binary( op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); } + if (use_mkldnn_tf32_linear() && input_t.scalar_type() == at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); + } + if (mkldnn_bias.has_value()) { ideep::inner_product_forward::compute_binary( mkldnn_input, diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index a9c094d85989..5a6e59fad786 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -1,7 +1,8 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include +#include +#include #include #if !AT_MKLDNN_ENABLED() @@ -53,7 +54,7 @@ bool mkldnn_fp16_gemm( c10::Half *c, int64_t ldc) { return false; } -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -85,6 +86,13 @@ void mkldnn_matmul_i8i8i32( TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support"); } +bool use_mkldnn_tf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result) { + return false; +} + } // namespace at::native @@ -107,6 +115,10 @@ static bool use_mkldnn_bf32_matmul() { return use_mkldnn_bf16_matmul() && at::globalContext().float32Precision("mkldnn", "matmul") == "bf16"; } +static bool use_mkldnn_tf32_matmul() { + return cpuinfo_has_x86_amx_fp16() && at::globalContext().float32Precision("mkldnn", "matmul") == "tf32"; +} + // returns an ideep::tensor // - dims: shape e.g: {M,N} // - idtype: ideep data type e.g: (f32, bf16, f16) @@ -144,7 +156,8 @@ mkldnn_gemm( bool bf16_usable = std::is_same_v && use_mkldnn_bf16_matmul(); bool fp16_usable = std::is_same_v && use_mkldnn_fp16_matmul(); bool bf32_usable = std::is_same_v && use_mkldnn_bf32_matmul(); - if ( !(bf16_usable || fp16_usable || bf32_usable) || + bool tf32_usable = std::is_same_v && use_mkldnn_tf32_matmul(); + if ( !(bf16_usable || fp16_usable || bf32_usable || tf32_usable) || (m * n * k <= 16 * 16 * 16) || (alpha == 0.0f)) { return false; } @@ -155,6 +168,7 @@ mkldnn_gemm( op_attr = ideep::attr_t::fuse_sum(); } if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path + if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path // NOTE: View as c-contiguous to avoid extra reordering in mkldnn // Use identity: C = AB <=> C^T = B^T A^T @@ -281,7 +295,7 @@ bool mkldnn_fp16_gemm( return mkldnn_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, @@ -339,6 +353,7 @@ void mkldnn_matmul( auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2; auto result_unsqueezed = result.dim() == 1 ? result.unsqueeze(1) : result; bool bf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_bf32_matmul(); + bool tf32_usable = mat1.scalar_type() == at::kFloat && use_mkldnn_tf32_matmul(); ideep::attr_t op_attr; // "addmm", "addbmm" "baddbmm" in pytorch allow bias to be 2-D or 3-D tensor @@ -346,6 +361,7 @@ void mkldnn_matmul( // to address their differences, we use mkldnn post ops to perform a fused "add" after matrix multiplication is over if (beta != 0.0f) op_attr = ideep::attr_t::fuse_sum(); if (bf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); // bf32 path + if (tf32_usable) op_attr.set_fpmath_mode(dnnl_fpmath_mode_tf32); // tf32 path // If alpha = 0, dose not need actually do gemm computation if (alpha == 0) return; @@ -412,70 +428,56 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){ } } -bool use_mkldnn_bf16_matmul( +template +bool use_mkldnn_typed_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { + bool dtype_check = false; + if constexpr (std::is_same_v) { #if defined(__aarch64__) - if (mkldnn_bf16_device_check_arm()) { - //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1 - //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well - return ( - use_mkldnn_bf16_matmul() && - (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) && - ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); - } else + if (mkldnn_bf16_device_check_arm()) { + // onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. + // Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16 + // inputs, allow it for float as well + dtype_check = use_mkldnn_bf16_matmul() && + ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)); + } +#else + dtype_check = dtype_check && use_mkldnn_bf16_matmul() && + (mat1.scalar_type() == kBFloat16); #endif - { - return ( - use_mkldnn_bf16_matmul() && - mat1.scalar_type() == kBFloat16 && - mat2.scalar_type() == kBFloat16 && - (!result.defined() || result.scalar_type() == kBFloat16) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); + } else if constexpr (std::is_same_v) { + dtype_check = dtype_check && use_mkldnn_fp16_matmul() && + (mat1.scalar_type() == kHalf); + } else if constexpr (std::is_same_v) { + dtype_check = dtype_check && + (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) && + (mat1.scalar_type() == kFloat); } -} - -bool use_mkldnn_fp16_matmul( - const Tensor& mat1, - const Tensor& mat2, - const Tensor& result) { - - return ( - use_mkldnn_fp16_matmul() && - mat1.scalar_type() == kHalf && - mat2.scalar_type() == kHalf && - (!result.defined() || result.scalar_type() == kHalf) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); -} - -bool use_mkldnn_bf32_matmul( - const Tensor& mat1, - const Tensor& mat2, - const Tensor& result) { - - return ( - use_mkldnn_bf32_matmul() && - mat1.scalar_type() == kFloat && - mat2.scalar_type() == kFloat && - (!result.defined() || result.scalar_type() == kFloat) && - mat1.numel() != 0 && - mat2.numel() != 0 && - checksize(mat1, mat2)); + if (!dtype_check) { + return false; + } + bool size_check = + mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2); + dtype_check = (mat1.scalar_type() == mat2.scalar_type()) && + (!result.defined() || result.scalar_type() == mat1.scalar_type()); + return dtype_check && size_check; } bool use_mkldnn_matmul( const Tensor& mat1, const Tensor& mat2, const Tensor& result) { - return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result)); + auto mat1_type = mat1.scalar_type(); + if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) { + return false; + } + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] { + return use_mkldnn_typed_matmul(mat1, mat2, result); + }); + return false; } static void _mkldnn_matmul_i8i8i32_with_primitive( diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h index e783d2372403..80247497d58f 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.h +++ b/aten/src/ATen/native/mkldnn/Matmul.h @@ -29,6 +29,11 @@ bool use_mkldnn_bf32_matmul( const Tensor& mat2, const Tensor& result_opt); +bool use_mkldnn_tf32_matmul( + const Tensor& mat1, + const Tensor& mat2, + const Tensor& result_opt); + // Try running mkldnn optimized gemm, or returns false if naive gemm would be faster bool mkldnn_bf16_gemm( TransposeType transa, TransposeType transb, @@ -62,7 +67,7 @@ oneDNN implicit reduced precision arithmetic feature https://github.com/mgouicem/oneDNN/tree/mgouicem/rfcs/implicit_downconvert/rfcs/20210301-computation-datatype to allow implicitly cast data type from FP32 to BF16 in onednn compute primitives */ -bool mkldnn_bf32_gemm( +bool mkldnn_reduced_f32_gemm( TransposeType transa, TransposeType transb, int64_t m, int64_t n, int64_t k, float alpha, diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 976b62c7ac4b..e6f87f5499a4 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -145,8 +145,6 @@ MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStre MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); MPSGraph* make_mps_graph(); -void printTensorNDArray(const TensorBase& t); -MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 142186b748b1..29d07a3f7aa3 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -327,7 +327,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { if (exclude_shape) { fmt::format_to(buf_iterator, "-1"); } else { - fmt::format_to(buf_iterator, getArrayRefString(tensor.sizes())); + fmt::format_to(buf_iterator, "{}", getArrayRefString(tensor.sizes())); } } fmt::format_to(buf_iterator, "]"); @@ -377,36 +377,6 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; } -void printTensorNDArray(const TensorBase& t) { - if (!t.is_mps()) - return; - if (t.numel() == 0) - return; - // Get shape and data type - auto selfShape = getMPSShape(t); - auto selfDType = getMPSDataType(t.scalar_type()); - - // Initialize data - id selfBuf = getMTLBufferStorage(t); - MPSGraphTensorData* tdata = [[[MPSGraphTensorData alloc] initWithMTLBuffer:selfBuf shape:selfShape - dataType:selfDType] autorelease]; - C10_CLANG_DIAGNOSTIC_PUSH() -#if C10_CLANG_HAS_WARNING("-Wobjc-method-access") - C10_CLANG_DIAGNOSTIC_IGNORE("-Wobjc-method-access") -#endif - [tdata printNDArray]; - C10_CLANG_DIAGNOSTIC_POP() -} - -MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType) { - id buffer = getMTLBufferStorage(tensor); - MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer - shape:shape - dataType:mpsType] autorelease]; - - return [tmpGraphTensorData mpsndarray]; -} - static std::vector getSortedStrides(const IntArrayRef& s) { std::vector idx(s.size()); iota(idx.begin(), idx.end(), 0); @@ -457,12 +427,22 @@ void printTensorNDArray(const TensorBase& t) { return result; } +// Should be called before initWithBuffer to prevent hard crashes with +// '[MPSNDArray initWithDevice:descriptor:isTextureBacked:] Error: NDArray dimension length > INT_MAX' +static void check_mps_shape(MPSShape* shape) { + for (NSNumber* elem in shape) { + const auto val = [elem longValue]; + TORCH_CHECK(val <= std::numeric_limits::max(), "MPSGaph does not support tensor dims larger than INT_MAX"); + } +} + MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) { id srcBuf = getMTLBufferStorage(t); MPSDataType mpsDataType = getMPSDataType(t.scalar_type()); MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:mpsDataType shape:sizes]; srcTensorDesc.preferPackedRows = YES; + check_mps_shape(sizes); MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf offset:t.storage_offset() * t.element_size() descriptor:srcTensorDesc] autorelease]; @@ -572,9 +552,9 @@ void printTensorNDArray(const TensorBase& t) { // Tensor is contiguous and has no storage offset. // Wrap it directly inside MPSGraphTensorData if ((_tensor.is_contiguous() && !_tensor.storage_offset()) || !useMPSStridedAPI || !is_macOS_15_0_or_newer) { - _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf - shape:mpsShape_ ? mpsShape_ : getMPSShape(_tensor) - dataType:dataType] autorelease]; + auto shape = mpsShape_ ? mpsShape_ : getMPSShape(_tensor); + check_mps_shape(shape); + _value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf shape:shape dataType:dataType] autorelease]; } else { IntArrayRef view_shape; if (mpsShape_) { @@ -583,8 +563,11 @@ void printTensorNDArray(const TensorBase& t) { MPSShape* mpsShape = getMPSShape(_tensor); MPSShape* mpsStrides = getMPSShape(_tensor.strides()); + check_mps_shape(mpsShape); auto storage_numel = src.storage().nbytes() / src.element_size(); + TORCH_CHECK(storage_numel <= std::numeric_limits::max(), + "MPSGaph does not support tensor dims larger than INT_MAX"); MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:@[ @(storage_numel) ]]; srcTensorDesc.preferPackedRows = YES; diff --git a/aten/src/ATen/native/mps/kernels/Pooling.h b/aten/src/ATen/native/mps/kernels/Pooling.h index 1d366f9620db..d72131bd4087 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.h +++ b/aten/src/ATen/native/mps/kernels/Pooling.h @@ -5,29 +5,30 @@ // maximum allowed pooling dimensions is N-2, because the input may have up to 2 // leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is // the default. -template +template struct PoolingParams { int32_t dims; int32_t pooling_dims; - ::c10::metal::array input_sizes; - ::c10::metal::array input_strides; - ::c10::metal::array output_sizes; - ::c10::metal::array output_strides; - ::c10::metal::array indices_sizes; - ::c10::metal::array indices_strides; - ::c10::metal::array kernel_size; - ::c10::metal::array stride; - ::c10::metal::array padding; - ::c10::metal::array dilation; + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; + ::c10::metal::array indices_sizes; + ::c10::metal::array indices_strides; + ::c10::metal::array kernel_size; + ::c10::metal::array stride; + ::c10::metal::array padding; + ::c10::metal::array dilation; + bool return_indices; }; -template +template struct PoolingBackwardParams { int32_t dims; int32_t pooling_dims; - ::c10::metal::array grad_input_sizes; - ::c10::metal::array grad_input_strides; - ::c10::metal::array grad_output_sizes; - ::c10::metal::array grad_output_strides; - ::c10::metal::array indices_strides; + ::c10::metal::array grad_input_sizes; + ::c10::metal::array grad_input_strides; + ::c10::metal::array grad_output_sizes; + ::c10::metal::array grad_output_strides; + ::c10::metal::array indices_strides; }; diff --git a/aten/src/ATen/native/mps/kernels/Pooling.metal b/aten/src/ATen/native/mps/kernels/Pooling.metal index 92a22c97f017..05ce39bd8316 100644 --- a/aten/src/ATen/native/mps/kernels/Pooling.metal +++ b/aten/src/ATen/native/mps/kernels/Pooling.metal @@ -6,6 +6,28 @@ using namespace metal; using namespace c10::metal; +template +struct IterBounds { + T start; + T end; +}; + +template +IterBounds get_input_iter_bounds( + constant int32_t* input_sizes, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation) { + auto d = dilation[dim]; + auto start = stride[dim] * pooling_dim_indices[dim] - padding[dim]; + auto end = min(start + kernel_size[dim] * d, input_sizes[dim]); + auto start_correction = d * ((-start - 1 + d) / d); + start += start < 0 ? start_correction : 0; + return IterBounds{start, end}; +} + // Iterates through all the input elements that this kernel needs to // apply max to. Specialized for 3 pooling dimensions. // TODO: Support any number of pooling dims @@ -14,82 +36,62 @@ void max_pool_3d_input_iter( constant T* input, device T* output, device int64_t* indices, - constant int64_t* input_sizes, - constant int64_t* input_strides, - device int64_t* work_pooling_dim_indices, - constant int64_t* kernel_size, - constant int64_t* stride, - constant int64_t* padding, - constant int64_t* dilation) { - int64_t o0 = work_pooling_dim_indices[0]; - int64_t o1 = work_pooling_dim_indices[1]; - int64_t o2 = work_pooling_dim_indices[2]; - - int64_t k0 = kernel_size[0]; - int64_t k1 = kernel_size[1]; - int64_t k2 = kernel_size[2]; - - int64_t s0 = stride[0]; - int64_t s1 = stride[1]; - int64_t s2 = stride[2]; - - int64_t d0 = dilation[0]; - int64_t d1 = dilation[1]; - int64_t d2 = dilation[2]; - - T max_value = 0; - int64_t max_index = -1; - - int64_t size12 = input_sizes[1] * input_sizes[2]; - - for (int64_t i0 = (s0 * o0) - padding[0]; - i0 < (s0 * o0 - padding[0] + k0 * d0) && i0 < input_sizes[0]; - i0 += d0) { - if (i0 < 0) { - continue; - } - int64_t offset0 = input_strides[0] * i0; - - for (int64_t i1 = (s1 * o1) - padding[1]; - i1 < (s1 * o1 - padding[1] + k1 * d1) && i1 < input_sizes[1]; - i1 += d1) { - if (i1 < 0) { - continue; - } - int64_t offset1 = input_strides[1] * i1; - - for (int64_t i2 = (s2 * o2) - padding[2]; - i2 < (s2 * o2 - padding[2] + k2 * d2) && i2 < input_sizes[2]; - i2 += d2) { - if (i2 < 0) { - continue; + constant int32_t* input_sizes, + constant int32_t* input_strides, + thread int32_t (&pooling_dim_indices)[3], + constant int32_t* kernel_size, + constant int32_t* stride, + constant int32_t* padding, + constant int32_t* dilation, + bool return_indices) { + auto bounds0 = get_input_iter_bounds<0>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds1 = get_input_iter_bounds<1>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + auto bounds2 = get_input_iter_bounds<2>( + input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation); + + auto d0 = dilation[0]; + auto d1 = dilation[1]; + auto d2 = dilation[2]; + + T max_value = input + [input_strides[0] * bounds0.start + input_strides[1] * bounds1.start + + input_strides[2] * bounds2.start]; + auto size12 = input_sizes[1] * input_sizes[2]; + auto max_index = + bounds0.start * size12 + bounds1.start * input_sizes[2] + bounds2.start; + + for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) { + auto offset0 = input_strides[0] * i0; + + for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) { + auto offset1 = input_strides[1] * i1; + + for (auto i2 = bounds2.start; i2 < bounds2.end; i2 += d2) { + auto offset2 = input_strides[2] * i2; + auto input_value = input[offset0 + offset1 + offset2]; + bool is_greater = input_value > max_value; + + max_value = is_greater ? input_value : max_value; + + if (return_indices) { + auto input_index = i0 * size12 + i1 * input_sizes[2] + i2; + max_index = is_greater ? input_index : max_index; } - int64_t offset2 = input_strides[2] * i2; - - const T input_value = input[offset0 + offset1 + offset2]; - int64_t input_index = i0 * size12 + i1 * input_sizes[2] + i2; - - T new_max_value = (max_index == -1 || input_value > max_value) - ? input_value - : max_value; - int64_t new_max_index = (max_index == -1 || input_value > max_value) - ? input_index - : max_index; - - max_value = new_max_value; - max_index = new_max_index; } } } - *output = max_value; - *indices = max_index; + if (return_indices) { + *indices = max_index; + } } struct PoolOffsets { - int64_t output; - int64_t indices; - int64_t input_leading; + int32_t output; + int32_t indices; + int32_t input_leading; PoolOffsets() : output(0), indices(0), input_leading(0) {} }; @@ -98,30 +100,35 @@ struct PoolOffsets { // calculate, `output[N, C, d, h, w]`. Also, find the offset of the input for // the leading dim indices, `input[N, C]`. Optionally, keep track of the output // pooling dimension indices, `[d, h , w]`. -PoolOffsets find_pool_offsets( - constant int64_t* output_sizes, - constant int64_t* output_strides, - constant int64_t* indices_strides, - constant int64_t* input_strides, - device int64_t* work_pooling_dim_indices, - int32_t dims, +// NOTE: This is templated per number of dimensions so that the compiler can +// unroll the loop, giving better performance. +template +PoolOffsets find_pool_offsets_dim_specific( + constant int32_t* output_sizes, + constant int32_t* output_strides, + constant int32_t* indices_strides, + constant int32_t* input_strides, + int32_t pooling_dim_indices[3], int32_t leading_dims, + bool return_indices, uint tid) { - int64_t output_idx = static_cast(tid); + auto output_idx = static_cast(tid); PoolOffsets offsets; - for (int64_t dim = dims - 1; dim >= 0; dim--) { - int64_t dim_idx = output_idx % (output_sizes[dim]); + for (auto dim = dims - 1; dim >= 0; dim--) { + auto dim_idx = output_idx % (output_sizes[dim]); offsets.output += output_strides[dim] * dim_idx; - offsets.indices += indices_strides[dim] * dim_idx; + if (return_indices) { + offsets.indices += indices_strides[dim] * dim_idx; + } if (dim < leading_dims) { offsets.input_leading += input_strides[dim] * dim_idx; } else { // Keep track of pooling dimension indices of the output element, so we // can use them in the input iteration later on. - if (work_pooling_dim_indices != nullptr) { - work_pooling_dim_indices[dim - leading_dims] = dim_idx; + if (pooling_dim_indices != nullptr) { + pooling_dim_indices[dim - leading_dims] = dim_idx; } } output_idx = output_idx / output_sizes[dim]; @@ -130,45 +137,77 @@ PoolOffsets find_pool_offsets( return offsets; } +PoolOffsets find_pool_offsets( + constant int32_t* output_sizes, + constant int32_t* output_strides, + constant int32_t* indices_strides, + constant int32_t* input_strides, + int32_t pooling_dim_indices[3], + int32_t dims, + int32_t leading_dims, + bool return_indices, + uint tid) { + switch (dims) { + case 5: + return find_pool_offsets_dim_specific<5>( + output_sizes, + output_strides, + indices_strides, + input_strides, + pooling_dim_indices, + leading_dims, + return_indices, + tid); + case 4: + return find_pool_offsets_dim_specific<4>( + output_sizes, + output_strides, + indices_strides, + input_strides, + pooling_dim_indices, + leading_dims, + return_indices, + tid); + } + return PoolOffsets(); +} + // Kernel computes one element of the output per kernel call. template kernel void max_pool( - constant void* input_ [[buffer(0)]], - device void* output_ [[buffer(1)]], - device void* indices_ [[buffer(2)]], - device int64_t* work_pooling_dim_indices_ [[buffer(3)]], - constant PoolingParams<5>& params [[buffer(4)]], + constant T* input [[buffer(0)]], + device T* output [[buffer(1)]], + device int64_t* indices [[buffer(2)]], + constant PoolingParams<5>& params [[buffer(3)]], uint tid [[thread_position_in_grid]]) { - int32_t pooling_dims = params.pooling_dims; - int32_t dims = params.dims; - constant int64_t* input_sizes = params.input_sizes.data(); - constant int64_t* input_strides = params.input_strides.data(); - constant int64_t* output_sizes = params.output_sizes.data(); - constant int64_t* output_strides = params.output_strides.data(); - constant int64_t* indices_strides = params.indices_strides.data(); - constant int64_t* kernel_size = params.kernel_size.data(); - constant int64_t* stride = params.stride.data(); - constant int64_t* padding = params.padding.data(); - constant int64_t* dilation = params.dilation.data(); - - int32_t leading_dims = dims - pooling_dims; - constant T* input = reinterpret_cast(input_); - device T* output = reinterpret_cast(output_); - device int64_t* indices = reinterpret_cast(indices_); + bool return_indices = params.return_indices; + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto input_sizes = params.input_sizes.data(); + auto input_strides = params.input_strides.data(); + auto output_sizes = params.output_sizes.data(); + auto output_strides = params.output_strides.data(); + auto indices_strides = params.indices_strides.data(); + auto kernel_size = params.kernel_size.data(); + auto stride = params.stride.data(); + auto padding = params.padding.data(); + auto dilation = params.dilation.data(); + + auto leading_dims = dims - pooling_dims; // This buffer keeps track of the pooling dimension indices of this thread's // element of the output. We need to fill it with the proper values below. - device int64_t* work_pooling_dim_indices = - work_pooling_dim_indices_ + tid * pooling_dims; + int32_t pooling_dim_indices[3]; PoolOffsets offsets = find_pool_offsets( output_sizes, output_strides, indices_strides, input_strides, - work_pooling_dim_indices, + pooling_dim_indices, dims, leading_dims, + return_indices, tid); output += offsets.output; @@ -181,11 +220,12 @@ kernel void max_pool( indices, input_sizes + leading_dims, input_strides + leading_dims, - work_pooling_dim_indices, + pooling_dim_indices, kernel_size, stride, padding, - dilation); + dilation, + return_indices); } // Finds the element in the grad input which corresponds to the index into the @@ -195,15 +235,15 @@ void max_pool_backward_impl( device AtomicType_t* grad_input, T grad_output_element, int32_t input_index, - constant int64_t* grad_input_sizes, - constant int64_t* grad_input_strides, + constant int32_t* grad_input_sizes, + constant int32_t* grad_input_strides, int32_t grad_input_leading_offset, int32_t pooling_dims) { int32_t size_prod = 1; int32_t pool_offset = 0; - for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) { - int32_t next_size_prod = grad_input_sizes[dim] * size_prod; + for (auto dim = pooling_dims - 1; dim >= 0; dim--) { + auto next_size_prod = grad_input_sizes[dim] * size_prod; pool_offset += grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod); size_prod *= grad_input_sizes[dim]; @@ -221,15 +261,15 @@ kernel void max_pool_backward( constant int64_t* indices [[buffer(2)]], constant PoolingBackwardParams<5>& params [[buffer(3)]], uint tid [[thread_position_in_grid]]) { - int32_t pooling_dims = params.pooling_dims; - int32_t dims = params.dims; - constant int64_t* grad_input_sizes = params.grad_input_sizes.data(); - constant int64_t* grad_input_strides = params.grad_input_strides.data(); - constant int64_t* grad_output_sizes = params.grad_output_sizes.data(); - constant int64_t* grad_output_strides = params.grad_output_strides.data(); - constant int64_t* indices_strides = params.indices_strides.data(); + auto pooling_dims = params.pooling_dims; + auto dims = params.dims; + auto grad_input_sizes = params.grad_input_sizes.data(); + auto grad_input_strides = params.grad_input_strides.data(); + auto grad_output_sizes = params.grad_output_sizes.data(); + auto grad_output_strides = params.grad_output_strides.data(); + auto indices_strides = params.indices_strides.data(); - int32_t leading_dims = dims - pooling_dims; + auto leading_dims = dims - pooling_dims; PoolOffsets offsets = find_pool_offsets( grad_output_sizes, @@ -239,6 +279,7 @@ kernel void max_pool_backward( nullptr, dims, leading_dims, + /*return_indices=*/true, tid); max_pool_backward_impl( @@ -253,11 +294,10 @@ kernel void max_pool_backward( #define REGISTER_MAX_POOL_OP(DTYPE) \ template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool( \ - constant void* input_ [[buffer(0)]], \ - device void* output_ [[buffer(1)]], \ - device void* indices_ [[buffer(2)]], \ - device int64_t* work_pooling_dim_indices_ [[buffer(3)]], \ - constant PoolingParams<5>& params [[buffer(4)]], \ + constant DTYPE * input [[buffer(0)]], \ + device DTYPE * output [[buffer(1)]], \ + device int64_t* indices [[buffer(2)]], \ + constant PoolingParams<5>& params [[buffer(3)]], \ uint tid [[thread_position_in_grid]]); #define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \ diff --git a/aten/src/ATen/native/mps/operations/Attention.mm b/aten/src/ATen/native/mps/operations/Attention.mm index 3f3c6b309fd6..69ec9af055ba 100644 --- a/aten/src/ATen/native/mps/operations/Attention.mm +++ b/aten/src/ATen/native/mps/operations/Attention.mm @@ -114,8 +114,22 @@ graph->maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, *attn_mask); maskedMM = [mpsGraph additionWithPrimaryTensor:maskedMM secondaryTensor:graph->maskTensor name:nil]; } + + // Account for case where all values were masked causing division by 0 in softmax (issue:#156707) + // Overwrites expected NANs in sm with zeros. + auto negInfTensor = [mpsGraph constantWithScalar:-INFINITY shape:maskedMM.shape dataType:maskedMM.dataType]; + auto elem_neg_inf = [mpsGraph equalWithPrimaryTensor:maskedMM secondaryTensor:negInfTensor name:nil]; + auto all_neg_infs_along_axis = [mpsGraph reductionAndWithTensor:elem_neg_inf axis:3 name:nil]; + auto zero_mask = [mpsGraph broadcastTensor:all_neg_infs_along_axis toShape:maskedMM.shape name:nil]; + auto zeroTensor = [mpsGraph constantWithScalar:0.0 shape:maskedMM.shape dataType:maskedMM.dataType]; + auto sm = [mpsGraph softMaxWithTensor:maskedMM axis:3 name:nil]; - auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:sm secondaryTensor:vTensor name:nil]; + MPSGraphTensor* correctedSM = [mpsGraph selectWithPredicateTensor:zero_mask + truePredicateTensor:zeroTensor + falsePredicateTensor:sm + name:nil]; + + auto output = [mpsGraph matrixMultiplicationWithPrimaryTensor:correctedSM secondaryTensor:vTensor name:nil]; graph->qTensor = qTensor; graph->kTensor = kTensor; graph->vTensor = vTensor; diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index 644cb80c1e44..e36ac4dc4524 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -62,15 +62,12 @@ return self; } -// returns false if tensor cannot be filled with fillBuffer() -static bool fill_mps_tensor_(Tensor& self, uint8_t value) { - if (self.is_contiguous()) { - MPSStream* stream = getCurrentMPSStream(); - auto storage_byte_offset = self.storage_offset() * self.itemsize(); - stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); - return true; - } - return false; +static Tensor& fill_mps_tensor_(Tensor& self, uint8_t value) { + TORCH_INTERNAL_ASSERT(self.is_contiguous()); + const auto stream = getCurrentMPSStream(); + auto storage_byte_offset = self.storage_offset() * self.itemsize(); + stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); + return self; } Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) { @@ -89,8 +86,20 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) { return self; } // check if it's possible to use fillBuffer() to fill the Tensor's storage - if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) - return self; + if (self.is_contiguous()) { + if (value.toDouble() == 0.0) { + return fill_mps_tensor_(self, 0); + } + if (self.scalar_type() == kBool) { + return fill_mps_tensor_(self, value.toBool()); + } + if (self.scalar_type() == kByte) { + return fill_mps_tensor_(self, value.toByte()); + } + if (self.scalar_type() == kChar) { + return fill_mps_tensor_(self, value.toChar()); + } + } return fill_scalar_mps_impl(self, value); } @@ -101,8 +110,6 @@ static bool fill_mps_tensor_(Tensor& self, uint8_t value) { value.dim(), " dimensions."); Scalar scalar_value = value.item(); - if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true) - return self; return fill_scalar_mps(self, scalar_value); } diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index efc77360bb99..d5e6500194f8 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -252,22 +252,20 @@ static void pool2d_template(const Tensor& input, } } -static std::vector copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) { - std::vector b; - if (a.size() == 1) { - b.assign(pooling_dims, a[0]); - } else { - b.assign(a.data(), a.data() + pooling_dims); +static std::vector copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) { + std::vector b(pooling_dims); + for (const auto dim : c10::irange(pooling_dims)) { + b[dim] = safe_downcast(a[a.size() == 1 ? 0 : dim]); } return b; } using PoolSizes = std::tuple, - std::vector, - std::vector, - std::vector, - std::vector>; + std::vector, + std::vector, + std::vector, + std::vector>; static PoolSizes process_pool_sizes(const Tensor& input, IntArrayRef kernel_size, @@ -368,7 +366,7 @@ static PoolSizes process_pool_sizes(const Tensor& input, } static void max_pool_with_indices_out_mps_template(const Tensor& output, - const Tensor& indices, + const std::optional& indices_opt, const Tensor& input, IntArrayRef _kernel_size, IntArrayRef _stride, @@ -379,10 +377,14 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output, const std::string& op_name) { auto [dims, output_size, kernel_size, stride, padding, dilation] = process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name); + const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt)); + const bool return_indices = indices.defined(); const auto memory_format = input.suggest_memory_format(); output.resize_(output_size, memory_format); - indices.resize_(output_size, memory_format); + if (return_indices) { + indices.resize_(output_size, memory_format); + } auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build(); @@ -395,33 +397,33 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output, params.dims = dims; params.pooling_dims = pooling_dims; - memcpy(params.input_sizes.data(), input.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.input_strides.data(), input.strides().data(), dims * sizeof(int64_t)); - memcpy(params.output_strides.data(), output.strides().data(), dims * sizeof(int64_t)); - memcpy(params.output_sizes.data(), output.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t)); - memcpy(params.indices_sizes.data(), indices.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int64_t)); - memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int64_t)); + params.return_indices = return_indices; + + for (const auto dim : c10::irange(dims)) { + params.input_sizes[dim] = safe_downcast(input.size(dim)); + params.input_strides[dim] = safe_downcast(input.stride(dim)); + params.output_sizes[dim] = safe_downcast(output.size(dim)); + params.output_strides[dim] = safe_downcast(output.stride(dim)); + if (return_indices) { + params.indices_sizes[dim] = safe_downcast(indices.size(dim)); + params.indices_strides[dim] = safe_downcast(indices.stride(dim)); + } + } + + memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t)); + memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int32_t)); dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_" + scalarToMetalTypeString(input)); - // Each thread needs to keep track of the indices into the pooling - // dimensions for the element of the output that it calculates. In other - // words, if the thread calculates `output[N, C, d, h, w]` for a 3D pool, - // the kernel needs to keep track of the indices `[d, h, w]`. So we create - // a device-side buffer for the threads to store these indices. - id work_pooling_dim_indices = [[device newBufferWithLength:numThreads * pooling_dims * sizeof(int64_t) - options:0] autorelease]; - getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input}); [computeEncoder setComputePipelineState:maxPoolPSO]; - mtl_setArgs(computeEncoder, input, output, indices, work_pooling_dim_indices, params); + mtl_setArgs( + computeEncoder, input, output, return_indices ? std::optional(indices) : std::nullopt, params); mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads); getMPSProfiler().endProfileKernel(maxPoolPSO); @@ -454,11 +456,14 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input, params.dims = dims; params.pooling_dims = pooling_dims; - memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t)); - memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t)); - memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t)); - memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t)); + + for (const auto dim : c10::irange(dims)) { + params.grad_input_sizes[dim] = safe_downcast(grad_input.size(dim)); + params.grad_input_strides[dim] = safe_downcast(grad_input.stride(dim)); + params.grad_output_sizes[dim] = safe_downcast(grad_output.size(dim)); + params.grad_output_strides[dim] = safe_downcast(grad_output.stride(dim)); + params.indices_strides[dim] = safe_downcast(indices.stride(dim)); + } dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 71128297d5bf..7948b5acd8e9 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -19,7 +19,14 @@ #include #endif -Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) { +std::tuple _fused_rms_norm_mps(const Tensor& input, + IntArrayRef normalized_shape, + const std::optional& weight_opt, + const std::optional eps) { + const Tensor weight = weight_opt.value().contiguous(); + const int64_t normalized_ndim = normalized_shape.size(); + auto eps_val = eps.value_or(std::numeric_limits::epsilon()); + TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors"); auto output = at::empty_like(input); const auto input_shape = input.sizes(); @@ -41,7 +48,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output)); id rms_norm_pso = lib.getPipelineStateForFunc(kernel); [computeEncoder setComputePipelineState:rms_norm_pso]; - mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1); + mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1); const auto maxThreadsPerGroup = static_cast([rms_norm_pso maxTotalThreadsPerThreadgroup]); size_t threadgroup_size = maxThreadsPerGroup; @@ -58,7 +65,7 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c } }); - return output; + return std::make_tuple(output, Tensor()); } } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 09ea127555f9..f3f3e0d582e5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1067,6 +1067,7 @@ CUDA: baddbmm_out_cuda MPS: baddbmm_out_mps XPU: baddbmm_out_xpu + MTIA: baddbmm_out_mtia SparseCsrCUDA: baddbmm_out_sparse_csr_cuda - func: baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor @@ -1376,6 +1377,7 @@ CUDA: bmm_out_cuda MPS: bmm_out_mps XPU: bmm_out_xpu + MTIA: bmm_out_mtia SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda SparseCsrCUDA: bmm_out_sparse_csr_cuda @@ -3314,9 +3316,15 @@ dispatch: CompositeImplicitAutograd: rms_norm_symint -- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor +- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) dispatch: + CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + CompositeImplicitAutograd: rms_norm_composite + +- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) + dispatch: + CUDA: _fused_rms_norm_backward_cuda - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method @@ -3432,7 +3440,7 @@ - func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor -- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor +- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor @@ -7059,6 +7067,7 @@ CUDA: addmm_out_cuda MPS: addmm_out_mps XPU: addmm_out_xpu + MTIA: addmm_out_mtia SparseCPU: addmm_out_sparse_dense_cpu SparseCUDA: addmm_out_sparse_dense_cuda SparseCsrCPU: addmm_out_sparse_compressed_cpu @@ -8962,7 +8971,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Scalar_out + CPU, CUDA, MTIA: eq_Scalar_out MPS: eq_scalar_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -8981,7 +8990,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: eq_Tensor_out + CPU, CUDA, MTIA: eq_Tensor_out MPS: eq_tensor_out_mps QuantizedCPU: eq_out_quantized_cpu tags: pointwise @@ -9374,7 +9383,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcmul_out + CPU, CUDA, MTIA: addcmul_out MPS: addcmul_out_mps tags: pointwise @@ -9395,7 +9404,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: addcdiv_out + CPU, CUDA, MTIA: addcdiv_out MPS: addcdiv_out_mps tags: pointwise @@ -14960,7 +14969,6 @@ - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) dispatch: CUDA: _scaled_dot_product_cudnn_attention_backward_cuda - NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda tags: nondeterministic_seeded - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) @@ -14993,11 +15001,6 @@ CUDA: _cudnn_attention_forward tags: nondeterministic_seeded -- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) - dispatch: - CUDA: _cudnn_attention_backward - tags: nondeterministic_seeded - - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 96c6ab8310f8..5b7476453407 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -349,63 +349,6 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda( return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor()); } -std::tuple _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda( - const Tensor& grad_out, - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& out, - const Tensor& logsumexp, - const Tensor& philox_seed, - const Tensor& philox_offset, - const Tensor& attn_bias, - const Tensor& cum_seq_q, - const Tensor& cum_seq_k, - const int64_t max_q, - const int64_t max_k, - double dropout_p, - bool is_causal, - std::optional scale) { - if (!grad_out.defined()) { - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); - } - auto [ - grad_out_buffer_reshaped, - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - output_buffer_reshaped] = - preprocessing::sdpa_nested_preprocessing_backward( - grad_out, - query, - key, - value, - out, - cum_seq_q, - cum_seq_k, - max_q, - max_k); - - auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped, - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - output_buffer_reshaped, - logsumexp, - philox_seed, - philox_offset, - attn_bias, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p, - is_causal, - scale); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); -} - - std::tuple _scaled_dot_product_flash_attention_backward_nested( const at::Tensor& grad_out_, const at::Tensor& query, diff --git a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp index 7108ecd64cac..c689132c7692 100644 --- a/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp +++ b/aten/src/ATen/native/quantized/cpu/ACLUtils.cpp @@ -81,7 +81,7 @@ DynamicQuantMatmul::DynamicQuantMatmul( auto src_q_tensor_info = arm_compute::TensorInfo( arm_compute::TensorShape(weight_dim_0, m), 1, - // ACL dyanamically quantized matmuls only support (signed) int8_t + // ACL dynamically quantized matmuls only support (signed) int8_t arm_compute::DataType::QASYMM8_SIGNED, // TODO: setting the initial offset value to int8_t max instead of zero, // because ACL currently skips MatrixBReduction calculation if the diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index 36f6140953f6..764d237e68b4 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -456,7 +456,7 @@ make_zero_points_and_scales_tensor( uint32_t groups = 1) { const int out_ch_idx = transpose ? 1 : 0; const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); - // Add 8 to account for bufferring needed by QNNPACK. + // Add 8 to account for buffering needed by QNNPACK. const auto num_output_channels_padded = num_output_channels + kPaddingChannels; const auto qtype = weight_contig.qscheme(); std::vector weight_zp(num_output_channels_padded, 0); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index f8651377ddf9..4cf3dfe2dbaa 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -366,7 +366,7 @@ Tensor ConvertConvWeightsToChannelLastTensor<3>( #endif // USE_FBGEMM namespace { - // This is really terrible, but couldnt figure out a better way to constexpr convert int to + // This is really terrible, but couldn't figure out a better way to constexpr convert int to // string and then perform string concatenation on/with it constexpr const char* _hack_int_to_class_name(int x) { switch(x) { diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index e6d86cf03df1..7dc9a93365e3 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -7,11 +7,13 @@ #include #ifdef USE_FBGEMM +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winconsistent-missing-destructor-override") #include C10_DIAGNOSTIC_POP() #include +C10_DIAGNOSTIC_POP() // The struct for the packed weight matrix (PackBMatrix) and the corresponding // column offsets used for the fully connect layer, which are both prepared in diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index cd4f253d0993..8624c9ef0336 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1277,7 +1277,7 @@ at::Tensor PackedConvWeightsOnednn::apply_impl( float sum_scale = has_accum ? accum.value().q_scale() : 1.0; int32_t sum_zero_point = has_accum ? accum.value().q_zero_point() : 0; if (has_accum) { - // Just tells we have these post op, the actual value such as scale and zero point will be setted later. + // Just tells we have these post op, the actual value such as scale and zero point will be set later. op_attr = kReluFused ? ideep::attr_t::residual_with_sum_zero_point() : ideep::attr_t::fuse_sum(); const ideep::scale_t accum_scale = ideep::scale_t(1, 1.0/sum_scale); const ideep::zero_point_t accum_zero_points = ideep::zero_point_t(1, sum_zero_point); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 502839a7d909..644ca6e67079 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -1118,8 +1118,9 @@ static at::Tensor linear_int8_with_onednn_weight( if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { #endif // Fall back to ref impl on old platforms because not supported + // Transpose weight to align with behavior in oneDNN return fp8_qlinear_onednn_ref( - input, input_scale, onednn_weight, weight_scales, bias, + input, input_scale, onednn_weight.t(), weight_scales, bias, output_scale, output_dtype, other, other_scale, binary_post_op, binary_alpha, unary_post_op, unary_post_op_args, unary_post_op_algorithm); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index e2d5278d5792..4ed50f6f8735 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -888,7 +888,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor input, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -908,7 +908,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor input, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -929,7 +929,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor run( at::Tensor /* input */, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -940,7 +940,7 @@ class QLinearUnpackedDynamicFp16 final { static at::Tensor meta( at::Tensor /* input */, const at::Tensor& weight, - const at::Tensor& bias) { + const std::optional& bias) { TORCH_CHECK( false, "This PyTorch installation was not built with FBGEMM operators"); } diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index 55ec1d8148bd..3bd68feca1c2 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -305,11 +305,12 @@ static inline at::Tensor pack_weight_to_onednn_tensor( #if defined(__powerpc__) if (is_fp8){ #else - if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { + if(is_fp8 && !cpuinfo_has_x86_amx_int8()) { #endif // oneDNN's fp8 requires AMX support // If AMX is not available, fall back to reference implementation - return weight; + // Transpose weight to align with behavior in oneDNN + return weight.t(); } std::vector w_dims = weight.sizes().vec(); auto w_data_type = is_fp8 diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh index 389430b043fe..5c52f1a020f1 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/arm64-v8a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh index 6f32950125e0..81da44097801 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/armeabi-v7a && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh index 5f19db582fb0..747704f1edfe 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh @@ -53,7 +53,7 @@ CMAKE_ARGS+=("-DANDROID_PIE=ON") CMAKE_ARGS+=("-DANDROID_STL=c++_static") CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/android/x86 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh index d155d6f7507d..8e867f18d3f9 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/arm64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh index 985315f74a66..34a95d194414 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=arm64e") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/arm64e && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh index 0431c090db68..37e57ab557fc 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/armv7 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh index e3f3d6b76231..2fd273219111 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=armv7s") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/armv7s && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh index e8952148e66a..b51b574d8136 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh @@ -40,7 +40,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=i386") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/i386 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh index 10a58b843e2a..a3430082e3e5 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh @@ -45,7 +45,7 @@ CMAKE_ARGS+=("-DIOS_ARCH=x86_64") CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") CMAKE_ARGS+=("-DENABLE_ARC=OFF") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/ios/x86_64 && cmake ../../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh index b429650c2184..ac61a4061b90 100755 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh @@ -27,7 +27,7 @@ CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") -# Use-specified CMake arguments go last to allow overridding defaults +# Use-specified CMake arguments go last to allow overriding defaults CMAKE_ARGS+=($@) cd build/local && cmake ../.. \ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c index a37c53a11529..29f5338f5c73 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c @@ -368,7 +368,7 @@ static enum pytorch_qnnp_status pytorch_qnnp_create_convolution_ndhwc_q8( case pytorch_qnnp_ukernel_type_xzp_gemm: { // TODO: XZP kernels won't be supporting per channel quantization. // For now we dont use XZP kernels anywhere. Probably deprecate it for now - // and ressurrect later if needed. + // and resurrect later if needed. const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S index 75eab4a1c305..ac06fa5973ec 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S @@ -20,28 +20,28 @@ # Args passed via stack. # TOS -# |-----------| -# |a | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r11 | 64 -# |a | 96 -# |w | 100 -# |c | 104 -# |c_stride | 108 -# |out ch indx| 112 -# |params | 116 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r11 | 64 +# |a | 96 +# |w | 100 +# |c | 104 +# |c_stride | 108 +# |out ch index| 112 +# |params | 116 +# |------------| # # void pytorch_q8conv_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S index 95d0a2ca8eba..1653b46e2d37 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S @@ -23,10 +23,10 @@ # Args passed via stack. # TOS -# |-----------| -# |out ch indx| 0 -# |params | 8 -# |-----------| +# |------------| +# |out ch index| 0 +# |params | 8 +# |------------| # void pytorch_q8conv_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S index 8fbea6498dce..f18605124356 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S @@ -20,28 +20,28 @@ # Args passed via stack. # TOS -# |-----------| -# |a_stride | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r9 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r9 | 64 -# |a_stride | 88 -# |w | 92 -# |c | 96 -# |c_stride | 100 -# |out ch indx| 104 -# |params | 108 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r9 | 64 +# |a_stride | 88 +# |w | 92 +# |c | 96 +# |c_stride | 100 +# |out ch index| 104 +# |params | 108 +# |------------| # # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S index de564d9d3d5a..c964bf2be7c4 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-dq-aarch32-neon.S @@ -33,29 +33,29 @@ # Args passed via stack. # TOS -# |-----------| -# |a_stride | 0 -# |w | 4 -# |c | 8 -# |c_stride | 12 -# |out ch indx| 16 -# |params | 20 -# |-----------| +# |------------| +# |a_stride | 0 +# |w | 4 +# |c | 8 +# |c_stride | 12 +# |out ch index| 16 +# |params | 20 +# |------------| # # After loading w pointer in ip reg. # And after pushing r4-r8 and d8-d15 on stack -# |-----------| -# |d8 - d15 | 0 -# |r4 - r7 | 64 -# |a_stride | 80 -# |w | 84 -# |b | 88 -# |c | 92 -# |c_stride | 96 -# |out ch indx| 100 -# |params | 104 -# |-----------| +# |------------| +# |d8 - d15 | 0 +# |r4 - r7 | 64 +# |a_stride | 80 +# |w | 84 +# |b | 88 +# |c | 92 +# |c_stride | 96 +# |out ch index| 100 +# |params | 104 +# |------------| # # void pytorch_q8gemm_ukernel_4x8__aarch32_neon( diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S index 52913d752861..51866fd3b1ed 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S @@ -22,10 +22,10 @@ # Args passed via stack. # TOS -# |-----------| -# |out ch indx| 0 -# |params | 8 -# |-----------| +# |------------| +# |out ch index| 0 +# |params | 8 +# |------------| # void pytorch_q8gemm_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S index b8bde0200687..63f667b04a28 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-dq-aarch64-neon.S @@ -14,11 +14,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_ukernel_8x8__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S index f1dd0a2cc052..4583e50046d6 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x4-packA-aarch32-neon.S @@ -32,7 +32,7 @@ # # Packed A format. -# 4kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 4kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +53,7 @@ # This locality helps in loading 8kx4m blocks of activations # Note when M is not multiple of 4, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # Also note that this packing is same as taking for 4x1 pattern. # This is because all the adjacent k's are laid next to each other @@ -109,7 +109,7 @@ k_loop: VLD1.8 {d2}, [r6]! VLD1.8 {d3}, [r7]! - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # A matrix # -------------------------------- # | | @@ -155,7 +155,7 @@ k_loop: VTRN.32 d2, d3 VSWP d1, d2 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 VST1.8 {q0}, [r2]! VST1.8 {q1}, [r2]! @@ -172,7 +172,7 @@ k_loop: VLD1.32 {d2[]}, [r6] VLD1.32 {d3[]}, [r7] - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # _d{0-3} are arm neon vector registers # va0 = _d0 = a0 a1 a2 a3 # va1 = _d1 = b0 b1 b2 b3 @@ -218,7 +218,7 @@ k_loop: VEXT.8 d0, d0, d1, #4 VEXT.8 d1, d2, d3, #4 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 VST1.8 {q0}, [r2] .p2align 4 diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S index 5b796bb2563c..d7a3aa6eaaf7 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c1x4-dq-packedA-aarch32-neon.S @@ -46,7 +46,7 @@ # |b | 12 # |c | 16 # |c_stride | 20 -# |out ch indx | 24 +# |out ch index | 24 # |params | 28 # |----------------| # @@ -61,7 +61,7 @@ # |b | 108 # |c | 112 # |c_stride | 116 -# |out ch indx | 120 +# |out ch index | 120 # |params | 124 # |----------------| # @@ -101,7 +101,7 @@ /* Add output_channel_index to the b_zero_point pointer */ ;\ ADD r4, r4, r5 ;\ ;\ - /* We enter the loop if r1 is atleast 1. */ ;\ + /* We enter the loop if r1 is at least 1. */ ;\ /* r1 = r1 - 1 will happen in the epilogue */ ;\ /* of the loop */ ;\ CMP r1, 1 ;\ @@ -222,7 +222,7 @@ /* Thus we will load accumulators back in q0, q1, q2, q3, q4, q5, q6, q7 */ ;\ /* When nr < 4, extra q values will be fetched from stack which may overlap */ ;\ /* with other parts of stack storing local variables. To avoid that we just */ ;\ - /* create a buffer of 128 bytes inbetween to make sure pointer increment */ ;\ + /* create a buffer of 128 bytes in between to make sure pointer increment */ ;\ /* never produces address that is beyond the stack frame of this function. */ ;\ SUB r9, sp, 140 ;\ /* Each iteration produce 4 values each of 4 bytes */ ;\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S index dd829f80e373..37db2adcad06 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/4x8c8x1-dq-packedA-aarch32-neon.S @@ -46,7 +46,7 @@ # |b | 12 # |c | 16 # |c_stride | 20 -# |out ch indx | 24 +# |out ch index | 24 # |params | 28 # |----------------| # @@ -61,7 +61,7 @@ # |b | 108 # |c | 112 # |c_stride | 116 -# |out ch indx | 120 +# |out ch index | 120 # |params | 124 # |----------------| # diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S index bff19de739b1..a5a91b9cb64f 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch32-neon.S @@ -32,7 +32,7 @@ # # Packed A format. -# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -53,7 +53,7 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -125,7 +125,7 @@ k_loop: VLD1.8 {d6}, [r10]! VLD1.8 {d7}, [r11]! - # Now we have 8x8 block of values that we will tranpose + # Now we have 8x8 block of values that we will transpose # A matrix # -------------------------------- # | | @@ -189,7 +189,7 @@ k_loop: VTRN.32 q0, q2 VTRN.32 q1, q3 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! @@ -213,7 +213,7 @@ k_loop: VLD1.32 {d6[]}, [r7] VLD1.32 {d7[]}, [r11] - # Now we have 4x8 block of values that we will tranpose + # Now we have 4x8 block of values that we will transpose # _d{0-3} are arm neon vector registers # va04 = _d0 = a0 a1 a2 a3 e0 e1 e2 e3 # va15 = _d1 = b0 b1 b2 b3 f0 f1 f2 f3 @@ -260,7 +260,7 @@ k_loop: VTRN.16 d0, d2 VTRN.16 d1, d3 - # Now store the tranposed values + # Now store the transposed values # d0, d1, d2, d3 # then d4, d5, d6, d7 contiguously VST1.8 {q0}, [r2]! diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S index 4cd788cf583b..b1f8fe719ca4 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-aarch64-neon.S @@ -9,7 +9,7 @@ #include # Packed A format. -# 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. +# 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. # Original A # --------- K ----------- -- (K + 4 - 1) / 4 -- # | | | | @@ -30,7 +30,7 @@ # This locality helps in loading 8kx8m blocks of activations # Note when M is not multiple of 8, the rest can contain arbitrary # data in packed A as we will not be writing those out. -# This wil be taken care by just copying the appropriate valid data +# This will be taken care by just copying the appropriate valid data # void pytorch_q8gemm_sparse_packA_ukernel_8x4__aarch32_neon( # size_t mr, @@ -93,7 +93,7 @@ k_loop: LD1 {v3.d}[0], [x7], 8 LD1 {v3.d}[1], [x11], 8 - # Now we have 8x8 block of values that we will tranpose + # Now we have 8x8 block of values that we will transpose # A matrix # ------------------------ # | | @@ -180,7 +180,7 @@ k_loop: LD1 {v3.s}[0], [x7] LD1 {v3.s}[1], [x11] - # Now we have 8x4 block of values that we will tranpose + # Now we have 8x4 block of values that we will transpose # A matrix # ---------------------------- # | | diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c index 4b0dd46fd4cf..df707d3d800e 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4-packA-sse2.c @@ -14,7 +14,7 @@ #include "8x4c1x4-packed-sse2.h" // This is a super slow kernel in that it does not use intrinsics to -// tranpose. Since this is for x86 we are not optimizing it. +// transpose. Since this is for x86 we are not optimizing it. // For ARM this will be optimized. void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( const size_t mr, @@ -24,7 +24,7 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( uint8_t* a_packed) { // Packed A format. - // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. + // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -45,7 +45,7 @@ void pytorch_q8gemm_sparse_packA_ukernel_8x4__sse2( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. - // This wil be taken care by just copying the appropriate valid data + // This will be taken care by just copying the appropriate valid data // Note that parts of A that are not filled are: // Remainder of M blocks. So some m values are random. This is ok diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h index 5503d6718172..ef771b4187b8 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x4c1x4-dq-packedA-sse2.h @@ -47,7 +47,7 @@ void KERNEL_NAME( const __m128i vzero = _mm_setzero_si128(); // Packed A format. - // 8kx4m blocks for alls blocks given 4 rows (4m) are placed in contiguous memory. + // 8kx4m blocks for all blocks given 4 rows (4m) are placed in contiguous memory. // Original A // --------- K ----------- -- (K + 4 - 1) / 4 -- // | | | | @@ -68,7 +68,7 @@ void KERNEL_NAME( // This locality helps in loading 8kx8m blocks of activations // Note when M is not multiple of 8, the rest can contain arbitrary // data in packed A as we will not be writing those out. - // This wil be taken care by just copying the appropriate valid data + // This will be taken care by just copying the appropriate valid data __m128i vacc_low[4]; __m128i vacc_high[4]; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S index aca408e89757..8af5c417da31 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c1x4-dq-packedA-aarch64-neon.S @@ -42,11 +42,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_sparse_1x4_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, @@ -234,7 +234,7 @@ /* v16, v17, v18, v19, v20, v21, v22, v23 */ XX\ /* When nr < 8, say nr = 1, extra v values will be fetched from stack which may overlap */ XX\ /* with other parts of stack storing local variables. To avoid that we just */ XX\ - /* create a buffer of 256 bytes inbetween to make sure pointer increment */ XX\ + /* create a buffer of 256 bytes in between to make sure pointer increment */ XX\ /* never produces address that is beyond the stack frame of this function. */ XX\ SUB x9, sp, 320 XX\ /* Each iteration produce 8 values each of 4 bytes */ XX\ @@ -287,7 +287,7 @@ LD1 {v22.4s}, [x9], 16 XX\ LD1 {v23.4s}, [x9] XX\ XX\ - /* We can tranpose one 4x4 block using macro */ XX\ + /* We can transpose one 4x4 block using macro */ XX\ /* TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 */ XX\ /* After this we have */ XX\ /* v8 : x00, x01, x02, x03 */ XX\ @@ -302,7 +302,7 @@ /* v20 : x24, x25, x26, x27 */ XX\ /* v22 : x34, x35, x36, x37 */ XX\ /* Similarly we can transpose other two 4x4 blocks and we get */ XX\ - /* tranposed 8x8 */ XX\ + /* transposed 8x8 */ XX\ XX\ TRANSPOSE_4X4_S32 v8, v10, v12, v14, v0, v1, v2, v3 XX\ TRANSPOSE_4X4_S32 v16, v18, v20, v22, v4, v5, v6, v7 XX\ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S index 2ba033c57c83..58602beb030d 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm_sparse/8x8c8x1-dq-packedA-aarch64-neon.S @@ -31,11 +31,11 @@ # Args passed via stack. # TOS -# |-----------| -# |c_stride | 0 -# |out ch indx| 8 -# |params | 16 -# |-----------| +# |------------| +# |c_stride | 0 +# |out ch index| 8 +# |params | 16 +# |------------| # void pytorch_q8gemm_dq_sparse_8x1_ukernel_8x8_packedA_w##W_INDEX_DTYPE_NUM_BITS##__aarch64_neon( # size_t mr, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h index 14ea25612485..14365d1ab3dd 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h @@ -238,7 +238,7 @@ static inline void pytorch_pack_q8conv_wrq( } } if (kzp != 0) { - // This part fills the packed wights with zero points for output channels + // This part fills the packed weights with zero points for output channels // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); @@ -360,7 +360,7 @@ static inline void pytorch_pack_q8deconv_wrq( } } if (kzp != 0) { - // This part fills the packed wights with zero points for output channels + // This part fills the packed weights with zero points for output channels // when they are not divisible by nr blocking parameter. // In that case for (size_t nr_block_offset = 0; nr_block_offset < (nr - nr_block_size); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c index e86130f2ccb6..74961b51ff63 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c @@ -93,7 +93,7 @@ void pytorch_qnnp_requantize_q31__scalar( * overflow is possible only when input is positive, and even when addition * of a rounding constant overflows 32-bit signed integer, it still doesn't * overflow 32-bit unsigned integer. Thus, in case of signed overflow, we - * can compute the result using unsigned arithmetics, specifically using + * can compute the result using unsigned arithmetic, specifically using * logical shift right instead of arithmetic shift right. * 3. Performs arithmetic shift as is, which will produce division result * rounded down. Then compute remainder of this division by a power of 2, diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc index a837974dd9fc..f535e4b99ed7 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc @@ -17,7 +17,7 @@ #include "requantization-tester.h" /* - * Precise scalar implementation using unsigned 32-bit arithmetics. + * Precise scalar implementation using unsigned 32-bit arithmetic. */ TEST(PRECISE__SCALAR_UNSIGNED32, exact_divide_by_po2) { @@ -83,7 +83,7 @@ TEST(PRECISE__SCALAR_UNSIGNED32, random_cases) { } /* - * Precise scalar implementation using unsigned 64-bit arithmetics. + * Precise scalar implementation using unsigned 64-bit arithmetic. */ TEST(PRECISE__SCALAR_UNSIGNED64, exact_divide_by_po2) { @@ -149,7 +149,7 @@ TEST(PRECISE__SCALAR_UNSIGNED64, random_cases) { } /* - * Precise scalar implementation using signed 64-bit arithmetics. + * Precise scalar implementation using signed 64-bit arithmetic. */ TEST(PRECISE__SCALAR_SIGNED64, exact_divide_by_po2) { @@ -302,7 +302,7 @@ TEST(GEMMLOWP__SCALAR, random_cases) { } /* - * Precise PSIMD implementation using unsigned 32-bit arithmetics. + * Precise PSIMD implementation using unsigned 32-bit arithmetic. */ TEST(PRECISE__PSIMD, exact_divide_by_po2) { diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index ea776fdf450f..230850998fda 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -171,7 +171,7 @@ void PackedLinearWeightCudnn::apply_impl_helper(const at::Tensor& quantized_outp return; } - // linear_op computes act_int8 * tranpose(w_int8) (matrix multiplication) + // linear_op computes act_int8 * transpose(w_int8) (matrix multiplication) // where act_int8 and w_int8 are the input and weight variables, resp. // output is a fp32 tensor auto linear_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index ba2cc9592d6c..7fe44de11e54 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -54,7 +54,7 @@ void check_maxpool2d_params( Tensor adaptive_avg_pool2d_quantized_cuda( const at::Tensor& input, IntArrayRef output_size) { -// TODO: renable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn +// TODO: re-enable these cudnn preprocessors like quantized_max_pool2d_cudnn below when we implement this function with cudnn #ifdef USE_CUDA // #if AT_CUDNN_ENABLED() // TODO: limit this to per tensor quantized tensors for now, though should be easy to adapt diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index cb19ec10ce04..550280dbf6d3 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -142,7 +142,7 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, bool reduce_range=False) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); - m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_dynamic_fp16_unpacked_weight(Tensor X, Tensor weight, Tensor? bias) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_relu_dynamic_fp16(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_leaky_relu(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i, float negative_slope) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::linear_tanh(Tensor X, __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y"), {at::Tag::pt2_compliant_tag}); diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index d7da40750ba1..805035cdd626 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -51,10 +51,10 @@ ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) { // Similarly, an upper bound is a value at *it with the smallest index // such that *it > value if such value exists, or last if does not. // Let is_lower = true and *it < value, then we know that *it and values - // preceeding *it cannot contain a lower bound, so we adjust initial iterator range + // preceding *it cannot contain a lower bound, so we adjust initial iterator range // from [first, first + count] to [first + step + 1, first + count - (step + 1)], // where +1 skips the element at which we have just evaluated *it < value. - // Samilar logic holds when is_lower = false. + // Similar logic holds when is_lower = false. if (is_lower ? *it < value : value >= *it) { first = ++it; count -= step + 1; diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp index 20a44c870939..cf854a84e7da 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionKernel.cpp @@ -79,7 +79,7 @@ struct CPUValueSelectionIntersectionKernel { const auto* ptr_argsort = argsort.const_data_ptr(); for (int64_t i = 0; i < n; ++i) { - // Exctract data + // Extract data auto* ptr_res_values = reinterpret_cast(ptr_res_values_bytes); const auto* ptr_lhs_values = reinterpret_cast(ptr_lhs_values_bytes); const auto lhs_nnz_idx = *reinterpret_cast(ptr_lhs_select_idx_bytes); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index b63d8ae80e50..752365d545de 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -730,7 +730,7 @@ static std::tuple sparse_mask_like_prepare_sparse_inp // is that these primitives might project first argument onto second one or // the other way around depending on which arguments are coalesced and which are // larger. This function prepares inputs for `sparse_mask` such that `t` is - // projected onto `mask` by sorting `t` if uncoalesced and artifically marking it + // projected onto `mask` by sorting `t` if uncoalesced and artificially marking it // as coalesced all while `mask` is set to uncoalesced. // The result of this projectionk is going to be uncoalesced, so it is up to the // user to set the corresponding flag correctly with respect to the operations' diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h index ec4c084a39cc..267c19561a29 100644 --- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h +++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h @@ -242,7 +242,7 @@ void _validate_compressed_sparse_indices_kernel( // Catch integer overflow from large dimensions. Otherwise, the // invariant checks may fail with bogus exceptions or succeed with // false-positive results when int64_t typed dimensions are cast to - // index dtype that corresponds to smaller interger type such as + // index dtype that corresponds to smaller integer type such as // int32_t. { AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [cdim, dim, nnz]() { diff --git a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h index f902f1e61c5e..530804099b6f 100644 --- a/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h +++ b/aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h @@ -112,7 +112,7 @@ struct LargestValuesGreedy { } }; -// We consider each rows independantly in order +// We consider each rows independently in order // This is to ensure that a row's sparsity pattern is only determined // by its values and the rows before (but never the rows after) // This enforces causality strictly diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index c9412d74e9cd..693ca536a319 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -242,7 +242,11 @@ __global__ void coalesceValuesKernel( // `if constexpr` when CUDA codes will be compiled under C++-17, see // gh-56055 for blockers. template +#ifdef USE_ROCM +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC*4) +#else C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) +#endif __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index 582778fdc299..c656dc71a660 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -65,7 +65,7 @@ void _csrmm2( csrvala, /* values of the sparse matrix, size = nnz */ CUSPARSE_INDEX_32I, /* data type of row offsets index */ CUSPARSE_INDEX_32I, /* data type of col indices */ - CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col indes */ + CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col index */ cusparse_value_type /* data type of values */ )); diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu index 867f103ba518..c6e3197a22a8 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu @@ -93,7 +93,7 @@ void create_general_description_(cusparseMatDescr_t& description_) { } // csrMatrixRef is used to have a representation of a raw CSR matrix representation -// comming from `sparse_sparse_matmul_cuda_kernel` function. +// coming from `sparse_sparse_matmul_cuda_kernel` function. // Moreover this implements a RAII guard for a cusparse descriptor template struct csrMatrixRef { diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 8647a199ad8e..7aad4309924d 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -207,7 +207,7 @@ Tensor qkv_projection( } else { // encoder-decoder attention // TODO: is there a more efficient way to set this up? - // TODO: can we stay nested insted of using cat? Probably just make a + // TODO: can we stay nested instead of using cat? Probably just make a // NestedTensor out of the matmul results or something? auto q_kv_weight_s = at::native::split_with_sizes(qkv_weight, {embed_dim, embed_dim * 2}, 0); diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 8513419db0a9..80049aa9a832 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -849,6 +849,16 @@ std::tuple #include -#include -#include #include #include #include @@ -98,14 +96,14 @@ std::tuple _flash_attention_backward( std::optional dk{std::nullopt}; std::optional dv{std::nullopt}; - // The kernel computes irregardless we will drop for this functions return + // The kernel computes regardless we will drop for this functions return Tensor grad_softmax; // Currently unused args: std::optional alibi_slopes{std::nullopt}; const float softcap = 0.0; - bool determinisitic{false}; + bool deterministic{false}; auto& ctx = at::globalContext(); if (ctx.deterministicAlgorithms()) { if (ctx.deterministicAlgorithmsWarnOnly()) { @@ -113,7 +111,7 @@ std::tuple _flash_attention_backward( "Flash Attention defaults to a non-deterministic algorithm. ", "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); } else { - determinisitic = true; + deterministic = true; } } @@ -148,7 +146,7 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, - determinisitic, + deterministic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -176,7 +174,7 @@ std::tuple _flash_attention_backward( non_null_window_right, #endif softcap, - determinisitic, + deterministic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -186,7 +184,7 @@ std::tuple _flash_attention_backward( return std::make_tuple(Tensor(), Tensor(), Tensor()); } -std::tuple _cudnn_attention_backward( +std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( const Tensor& grad_out, const Tensor& query, const Tensor& key, @@ -213,117 +211,57 @@ std::tuple _cudnn_attention_backward( } } - const bool is_nested = cum_seq_q.defined(); + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_qk = query.size(3); + const int64_t head_dim_v = value.size(3); const int64_t max_seqlen_batch_q = query.size(2); const int64_t max_seqlen_batch_k = key.size(2); - if (!is_nested) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_qk = query.size(3); - const int64_t head_dim_v = value.size(3); - - // This is needed because SaveVariable automatically converts - // std::optional to undefined tensor - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); - } - } - - const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - run_cudnn_SDP_bprop(batch_size /*int64_t b*/, - num_heads /*int64_t h*/, - max_q/*int64_t s_q*/, - max_k/*int64_t s_kv*/, - head_dim_qk /*int64_t d_qk*/, - head_dim_v /*int64_t d_v*/, - softmax_scale /*float scaling_factor*/, - is_causal /*bool is_causal*/, - dropout_p /*float dropout_probability*/, - query /*const Tensor& q*/, - key /*const Tensor& k*/, - value /*const Tensor& v*/, - attn_bias_ /*const std::optional& attn_bias*/, - out /*const Tensor& o*/, - grad_out/*const Tensor& dO*/, - logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, - dq/*Tensor& dQ*/, - dk/*Tensor& dK*/, - dv/*Tensor& dV*/, - philox_seed/*Tensor& dropoutseed*/, - philox_offset/*Tensor& dropoutoffset*/); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); - } else { - // BHSD ... - const int64_t batch_size = cum_seq_q.size(0) - 1; - const int64_t num_heads_q = query.size(-2); - const int64_t num_heads_k = key.size(-2); - const int64_t num_heads_v = value.size(-2); - const int64_t head_dim_qk = query.size(-1); - const int64_t head_dim_v = value.size(-1); - std::optional attn_bias_; - if (attn_bias.defined()) { - attn_bias_ = attn_bias; - } - if (attn_bias_.has_value()) { - const auto bias_dim = attn_bias_.value().dim(); - if (bias_dim == 2) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else if (bias_dim == 3) { - attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); - } else { - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); - TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); - } + // This is needed because SaveVariable automatically converts + // std::optional to undefined tensor + std::optional attn_bias_; + if (attn_bias.defined()) { + attn_bias_ = attn_bias; + } + if (attn_bias_.has_value()) { + const auto bias_dim = attn_bias_.value().dim(); + if (bias_dim == 2) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else if (bias_dim == 3) { + attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); + } else { + TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); } - - auto dq = at::empty_like(query); - auto dk = at::empty_like(key); - auto dv = at::empty_like(value); - - const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked(); - run_cudnn_SDP_bprop_nestedtensor( - batch_size, - num_heads_q, - num_heads_k, - num_heads_v, - max_seqlen_batch_q, - max_seqlen_batch_k, - head_dim_qk, - head_dim_v, - softmax_scale, - is_causal, - dropout_p, - cum_seq_q, - cum_seq_k, - query, - key, - value, - attn_bias_, - out, - grad_out, - logsumexp, - dq, - dk, - dv, - philox_seed, - philox_offset); - return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } + + const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); + auto dq = at::empty_like(query); + auto dk = at::empty_like(key); + auto dv = at::empty_like(value); + run_cudnn_SDP_bprop(batch_size /*int64_t b*/, + num_heads /*int64_t h*/, + max_q/*int64_t s_q*/, + max_k/*int64_t s_kv*/, + head_dim_qk /*int64_t d_qk*/, + head_dim_v /*int64_t d_v*/, + softmax_scale /*float scaling_factor*/, + is_causal /*bool is_causal*/, + dropout_p /*float dropout_probability*/, + query /*const Tensor& q*/, + key /*const Tensor& k*/, + value /*const Tensor& v*/, + attn_bias_ /*const std::optional& attn_bias*/, + out /*const Tensor& o*/, + grad_out/*const Tensor& dO*/, + logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/, + dq/*Tensor& dQ*/, + dk/*Tensor& dK*/, + dv/*Tensor& dV*/, + philox_seed/*Tensor& dropoutseed*/, + philox_offset/*Tensor& dropoutoffset*/); + return std::make_tuple(std::move(dq), std::move(dk), std::move(dv)); } std::tuple @@ -1125,40 +1063,4 @@ std::tuple _scaled_dot_product_e } } -std::tuple _scaled_dot_product_cudnn_attention_backward_cuda( - const Tensor& grad_out, - const Tensor& query, - const Tensor& key, - const Tensor& value, - const Tensor& out, - const Tensor& logsumexp, - const Tensor& philox_seed, - const Tensor& philox_offset, - const Tensor& attn_bias, - const Tensor& cum_seq_q, - const Tensor& cum_seq_k, - const int64_t max_q, - const int64_t max_k, - double dropout_p, - bool is_causal, - std::optional scale) { - return at::_cudnn_attention_backward( - grad_out, - query, - key, - value, - out, - logsumexp, - philox_seed, - philox_offset, - attn_bias, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p, - is_causal, - scale); -} - } // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 9eed9b69d8bd..a4e37da1a4ae 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -32,7 +32,9 @@ #endif +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include +C10_DIAGNOSTIC_POP() #include @@ -389,20 +391,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -577,20 +573,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q std::optional gen_) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -838,15 +828,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si #endif if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -855,7 +838,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -885,7 +868,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1055,15 +1038,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm80 = dprops->major == 8 && dprops->minor == 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); + bool is_dropout = p_dropout > 0.0; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1071,7 +1048,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); @@ -1106,7 +1083,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); if (head_size > 192 && (head_size <= 224 || is_dropout)) { - TORCH_CHECK(is_sm80 || is_sm90 || is_sm10x || is_sm120_or_sm121, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); + TORCH_CHECK(is_sm80_or_newer, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1280,20 +1257,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he ) { auto dprops = at::cuda::getCurrentDeviceProperties(); - // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; - bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; - bool is_sm90 = dprops->major == 9 && dprops->minor == 0; - bool is_sm10x = dprops->major == 10 && dprops->minor >= 0; - bool is_sm120_or_sm121 = dprops->major == 12 && dprops->minor <= 1; - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer."); - // We will support Turing in the near future - // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer."); + bool is_sm80_or_newer = (dprops->major * 10) >= 80; + TORCH_CHECK(is_sm80_or_newer, "FlashAttention only supports Ampere GPUs or newer."); auto q_dtype = q.dtype(); TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { - TORCH_CHECK(is_sm120_or_sm121 || is_sm10x || is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); + TORCH_CHECK(is_sm80_or_newer, "bfloat16 is only supported on Ampere GPUs or newer"); } TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype"); @@ -1328,7 +1299,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h index fd7982e5f699..7115cb07a793 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h @@ -125,7 +125,7 @@ class MemoryEfficientAttentionNormalize { FragmentSource const& source) const { assert(!isFirst); - // Convert source to interal compute numeric type + // Convert source to internal compute numeric type NumericArrayConverter source_converter; NumericArrayConverter @@ -164,7 +164,7 @@ class MemoryEfficientAttentionNormalize { const { assert(isFirst); - // Convert source to interal compute numeric type + // Convert source to internal compute numeric type NumericArrayConverter accumulator_converter; diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h index 229c59d68347..3c3566512b45 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h @@ -88,7 +88,7 @@ class CustomMmaBase { Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>; - /// Number of warp-level GEMM oeprations + /// Number of warp-level GEMM operations static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h index c7a9915fed6d..e75a1b9001e0 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h @@ -68,7 +68,7 @@ namespace threadblock { /// ForwardTileIterator /// template < - typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) typename Element_, ///< Element data type bool ScatterD = false, ///< Scatter D operand or not bool UseCUDAStore = false> diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index ae649e99c4cd..20495a05474b 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -245,7 +245,7 @@ struct AttentionBackwardKernel { static constexpr int64_t kWarpSize = 32; // If this is true, we store and accumulate dK/dV in RF - // rather than going back to gmem everytime + // rather than going back to gmem every time static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; static_assert( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index d1101b6597a5..4b198f4d6d2d 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -57,28 +57,21 @@ namespace sdp { namespace { -// tracks whether we've set the default priority order once, to avoid setting -// it redundantly or overwriting a user-specified priority order -// when the priority order context manager is used before the default priority -// order is initialized the following happens: -// (1) the current priority order is queried -// (2) priority_order() is called, which initializes it to the default as init_ is false -// (3) the user-specified priority order is set -// (3.1) we are in the priority context... -// (3.2) we exit the priority context... -// (4) the previous priority order (default) is restored -bool priority_order_init_ = false; - // TODO(eqy): more benchmarking to determine whether this should include sm86/89 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py bool check_prefer_cudnn_attention() { - static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true; - if (!prefer_cudnn) { - return false; - } -#if (defined(CUDNN_VERSION) && (CUDNN_VERSION > 90000)) + // TODO(eqy): Re-enable by default after upgrading to a release later than 9.5.0 + // see context: https://github.com/pytorch/pytorch/issues/138340 + // return false; +#if defined(CUDNN_VERSION) + +#if CUDNN_VERSION > 90000 auto dprops = at::cuda::getCurrentDeviceProperties(); - return dprops->major >= 9 && !dprops->minor; + return dprops->major >= 9; +#else + return false; +#endif + #else return false; #endif @@ -86,16 +79,6 @@ bool check_prefer_cudnn_attention() { // flash_attention V2 is universally faster than efficient_attention and Math std::array priority_order(sdp_params const& params) { - if (!priority_order_init_) { - priority_order_init_ = true; - if (check_prefer_cudnn_attention()) { - const std::vector cudnn_order = {static_cast(at::SDPBackend::cudnn_attention), - static_cast(at::SDPBackend::flash_attention), - static_cast(at::SDPBackend::efficient_attention), - static_cast(at::SDPBackend::math)}; - at::globalContext().setSDPPriorityOrder(cudnn_order); - } - } return at::globalContext().sDPPriorityOrder(); } @@ -395,7 +378,7 @@ bool check_flash_causal_non_square_seqlens(sdp_params const& params, bool debug) bool check_all_tensors_on_device(sdp_params const& params, bool debug) { // Check that all tensors are on the GPU device - // This should be handled by the stub dispatch, but whe call can_use_*_attention + // This should be handled by the stub dispatch, but we call can_use_*_attention // directly from python we need to ensure that the tensors are on cuda if (params.query.device().type() != at::DeviceType::CUDA) { if (debug) { @@ -431,7 +414,12 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } auto head_dim_limit = 128; - // TODO(eqy): add head dim >= 256 cases once support is finalized + if (cudnn_version >= 90501) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (dprops->major == 9 && !dprops->minor) { + head_dim_limit = 256; + } + } if (d_qk > head_dim_limit || d_v > head_dim_limit) { if (debug) { TORCH_WARN("head_dim should be no more than ", head_dim_limit); @@ -465,15 +453,9 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { return false; } } - if (s_k == 1) { - if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support key/value sequence length 1."); - } - return false; - } - if (s_q == 1 && params.dropout != 0.0) { + if (s_q == 1 || s_k == 1) { if (debug) { - TORCH_WARN_ONCE("cudnn SDPA does not support query sequence length 1 with dropout."); + TORCH_WARN_ONCE("cudnn SDPA does not support sequence length 1."); } return false; } @@ -581,9 +563,9 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { const auto dprop = at::cuda::getCurrentDeviceProperties(); // Check that the input is nested - if ((dprop->major == 9 || dprop->major == 10) && has_for_nested_inputs(params)) { + if (dprop->major != 9 && has_for_nested_inputs(params)) { if (debug) { - TORCH_WARN("cuDNN SDPA supports nested tensors on SM 9.0, SM 10.0."); + TORCH_WARN("CuDNN SDPA supports nested tensors on SM 9.0."); } return false; } @@ -607,7 +589,7 @@ bool check_runtime_disabled_cudnn(sdp_params const& params, bool debug) { // sdp kernels if (!at::globalContext().userEnabledCuDNNSDP()) { if (debug) { - TORCH_WARN("cuDNN attention has been runtime disabled."); + TORCH_WARN("CuDNN attention has been runtime disabled."); } return false; } @@ -638,7 +620,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { - TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); } return false; #endif @@ -648,8 +630,10 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { c10::array_of( check_runtime_disabled_cudnn, check_for_nested_inputs, + check_nonzero_sequence_lengths_dense, check_all_tensors_on_device, check_tensor_shapes, + check_cudnn_tensor_shapes, check_cudnn_deterministic, check_dtypes_low_precision, check_attn_mask_shape, @@ -662,10 +646,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } constexpr auto dense_constraints = c10::array_of( - check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense, - check_cudnn_tensor_shapes + check_batch_size_and_num_heads_dense ); if (has_only_dense_inputs(params)) { @@ -882,7 +864,7 @@ SDPBackend select_sdp_backend(sdp_params const& kernel_params) { sdp::can_use_mem_efficient_attention(kernel_params, print_debug); TORCH_WARN("Flash attention kernel not used because:"); sdp::can_use_flash_attention(kernel_params, print_debug); - TORCH_WARN("cuDNN attention kernel not used because:"); + TORCH_WARN("CuDNN attention kernel not used because:"); sdp::can_use_cudnn_attention(kernel_params, print_debug); TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.") return SDPBackend::error; diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index b38122248db8..aedb205e5710 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -82,7 +82,7 @@ aotriton::TensorView mk_aotensor(const at::Tensor& q, std::string_view ten { const auto strides = q.strides(); int real_rank = strides.size(); - if (real_rank != Rank) { // Lazy convertion of tensor_name + if (real_rank != Rank) { // Lazy conversion of tensor_name TORCH_CHECK(false, std::string(tensor_name) + "'s rank should be " + std::to_string(Rank) + " but is " + std::to_string(real_rank)); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index 1908096e2f6f..05523f75caa4 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -401,7 +401,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot CHECK_SHAPE(cu_seqlens_k, batch_size + 1); // AOTriton's varlen API needs input shapes be - // (1, num_heads, total sequence lenght, head dimension) + // (1, num_heads, total sequence length, head dimension) at::Tensor q_padded, k_padded, v_padded; at::Tensor out, out_padded; q_padded = q.unsqueeze(0).transpose(1, 2); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index 20ad315d3025..ece6f29877ab 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -209,7 +209,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads const int total_q = q.size(0); const int total_k = k.size(0); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index aa5c2b6cdd64..c63ca928613e 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -503,17 +504,27 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool if (ignore_singleton_dim){ qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1; } - if (!qkv_strides_equal_1) { + bool is_cpu = params.query.device().type() == c10::DeviceType::CPU; + bool mask_stride_equal_1 = params.attn_mask.has_value() + ? params.attn_mask.value().sym_stride(-1) == 1 + : true; + bool mask_stride_valid = is_cpu ? true : mask_stride_equal_1; + if (!(qkv_strides_equal_1 && mask_stride_valid)) { if (debug) { - TORCH_WARN( - "All fused kernels require the last dimension of the input to have stride 1. ", - "Got Query.stride(-1): ", - params.query.sym_stride(-1), - ", Key.stride(-1): ", - params.key.sym_stride(-1), - ", Value.stride(-1): ", - params.value.sym_stride(-1), - " instead."); + std::ostringstream message; + message + << "All fused kernels require the last dimension of the input to have stride 1. "; + message << "Got Query.stride(-1): " << params.query.sym_stride(-1) + << ", Key.stride(-1): " << params.key.sym_stride(-1) + << ", Value.stride(-1): " << params.value.sym_stride(-1); + + if (params.attn_mask.has_value()) { + message + << ", Attn_mask.stride(-1): " + << params.attn_mask.value().sym_stride(-1) + << " (GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not)."; + } + TORCH_WARN(message.str()); } return false; diff --git a/aten/src/ATen/native/utils/ParamsHash.h b/aten/src/ATen/native/utils/ParamsHash.h index 6b7894cb8549..4c9d97328ad6 100644 --- a/aten/src/ATen/native/utils/ParamsHash.h +++ b/aten/src/ATen/native/utils/ParamsHash.h @@ -41,7 +41,7 @@ struct ParamsEqual { }; // Provide explicit byte-for-byte constructors to avoid uwittingly leaving -// padding bytes unitialized (e.g., when passing Params by value) +// padding bytes uninitialized (e.g., when passing Params by value) template struct ParamsWrapper { T pod; diff --git a/aten/src/ATen/native/vulkan/api/Types.h b/aten/src/ATen/native/vulkan/api/Types.h index 548703aa8a95..1202a3bd7393 100644 --- a/aten/src/ATen/native/vulkan/api/Types.h +++ b/aten/src/ATen/native/vulkan/api/Types.h @@ -71,7 +71,7 @@ inline VkFormat to_vkformat(const ScalarType t) { /* * Given a `VkFormat`, return the `ScalarType` that best represents the data - * type of invidivual elements in an image texture of the `VkFormat`. Note that + * type of individual elements in an image texture of the `VkFormat`. Note that * this mapping is different from the `to_vkformat()` function, since different * `ScalarType`s may use the same `VkFormat`. */ diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl index d5fe3c232e44..47a2630aaafb 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d.glsl @@ -75,7 +75,7 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight - // tensor. the x coordinate is multipled by 4 since each group of 4 channels + // tensor. the x coordinate is multiplied by 4 since each group of 4 channels // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl index c4728a9bb94e..d4188d658059 100644 --- a/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/conv2d_pw.glsl @@ -39,7 +39,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * Computes a 2D pointwise convolution of a 2x2 output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel - * size is only 1x1, making it much easier to re-use loaded texels from uKernel. + * size is only 1x1, making it much easier to reuse loaded texels from uKernel. */ void main() { const ivec3 gpos = ivec3(gl_GlobalInvocationID); diff --git a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl index 8080db6120aa..1f66a5fe1915 100644 --- a/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl +++ b/aten/src/ATen/native/vulkan/glsl/image_to_nchw_uint.glsl @@ -57,7 +57,7 @@ void main() { // out CxHxW plane. ivec4 c_index = pos_in_batch / uBlock.in_extents.w; - // we devide pos_in_batch by HxW, to compute the channel index + // we divide pos_in_batch by HxW, to compute the channel index ivec4 pos_in_hw = pos_in_batch % uBlock.in_extents.w; // we compute the reminder mod HxW, to find the positions in the flatten diff --git a/aten/src/ATen/native/vulkan/glsl/indexing.h b/aten/src/ATen/native/vulkan/glsl/indexing.h index 2bda5a236240..c34ce25001ef 100644 --- a/aten/src/ATen/native/vulkan/glsl/indexing.h +++ b/aten/src/ATen/native/vulkan/glsl/indexing.h @@ -1,12 +1,12 @@ /* - * Computes a 4D tensor co-ordinate from a linearized index + * Computes a 4D tensor coordinate from a linearized index */ uvec4 idx_to_coord(const uint idx, const uvec4 strides, const uvec4 sizes) { return ivec4(mod(idx / strides, sizes)); } /* - * Computes a linearized index from a 4D tensor co-ordinate + * Computes a linearized index from a 4D tensor coordinate */ uint coord_to_idx(const uvec4 coord, const uvec4 strides) { return int(dot(coord * strides, ivec4(1))); diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl index 0b4ee355a064..bc13655d01e0 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -96,7 +96,7 @@ void main() { // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight - // tensor. the x coordinate is multipled by 4 since each group of 4 channels + // tensor. the x coordinate is multiplied by 4 since each group of 4 channels // is folded into the X axis. The y coordinate is offset based on the z // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; diff --git a/aten/src/ATen/native/vulkan/ops/Tile.cpp b/aten/src/ATen/native/vulkan/ops/Tile.cpp index 2ea62e909119..d39fd951106c 100644 --- a/aten/src/ATen/native/vulkan/ops/Tile.cpp +++ b/aten/src/ATen/native/vulkan/ops/Tile.cpp @@ -18,7 +18,7 @@ namespace { using namespace api::utils; Tensor tile(const Tensor& self, const IntArrayRef repeats) { - // If self.size() > len(reps), reps is promoted to self.size() by pre-pending + // If self.size() > len(reps), reps is promoted to self.size() by prepending // 1’s to it to keep the same behaviour as `numpy.tile`. // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated // as (1, 1, 2, 2). diff --git a/aten/src/ATen/nnapi/nnapi_bind.cpp b/aten/src/ATen/nnapi/nnapi_bind.cpp index 120c62cd4ab9..8f40ee404568 100644 --- a/aten/src/ATen/nnapi/nnapi_bind.cpp +++ b/aten/src/ATen/nnapi/nnapi_bind.cpp @@ -26,7 +26,7 @@ static void load_platform_library() { (void)run_once; } -// NnapiCompilation functon definitions: +// NnapiCompilation function definitions: // Could possibly call load_platform_library in constructor, but error reporting // can be complicated if the constructor is called during model loading. diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 29fbc8270a45..8ec70a1682f3 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -666,7 +666,7 @@ void record_function_with_scope_and_debug_handle( guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ } -// Helper macros to record LITE INTERPETER scope events with debug handles +// Helper macros to record LITE INTERPRETER scope events with debug handles #define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ fn, debug_handle, inputs) \ RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index e111a88b3309..f210402e543a 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -64,7 +64,7 @@ Tensor TensorMaker::make_tensor() { if (strides_) { auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); if (storage_offset_) { - storage_size += storage_offset_.value(); + storage_size += storage_offset_.value() * itemsize; } return storage_size; } @@ -75,7 +75,7 @@ Tensor TensorMaker::make_tensor() { } auto storage_size = size * itemsize; if (storage_offset_) { - storage_size += storage_offset_.value(); + storage_size += storage_offset_.value() * itemsize; } return storage_size; } diff --git a/aten/src/ATen/templates/RegisterDispatchKey.cpp b/aten/src/ATen/templates/RegisterDispatchKey.cpp index 968729f85267..39c85b00d7a1 100644 --- a/aten/src/ATen/templates/RegisterDispatchKey.cpp +++ b/aten/src/ATen/templates/RegisterDispatchKey.cpp @@ -5,13 +5,11 @@ // NOTE: This condition is true for all PyTorch internal libraries, it // just excludes external projects such as torch_xla which -// re-use some of the PyTorch codegen machinery. +// reuse some of the PyTorch codegen machinery. #if defined(CAFFE2_BUILD_MAIN_LIB) || \ defined(TORCH_CUDA_BUILD_MAIN_LIB) || \ defined(TORCH_HIP_BUILD_MAIN_LIB) || \ - defined(TORCH_XPU_BUILD_MAIN_LIB) || \ - defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \ - defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB) + defined(TORCH_XPU_BUILD_MAIN_LIB) #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #endif diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 050d882f42bf..8ae2dee1ce50 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -491,7 +491,7 @@ class TORCH_API Tensor: public TensorBase { "attribute won't be populated during autograd.backward(). If you indeed want the .grad " "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. " "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor " - "instead. See github.com/pytorch/pytorch/pull/30531 for more informations."); + "instead. See github.com/pytorch/pytorch/pull/30531 for more information."); } return maybe_grad; } diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 8dd2e59ce2dd..0937de455282 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -517,3 +517,21 @@ TEST(BasicTest, BasicStdTestCPU) { t3.join(); t4.join(); } + +TEST(BasicTest, TestForBlobResizeCPU) { + // Checks that for_blob can correctly create tensors with non-empty offset and resize them + std::array storage; + std::iota(storage.begin(), storage.end(), 1); + auto t = at::for_blob(storage.data(), {3,}).storage_offset(3).options(c10::TensorOptions(kInt)).make_tensor(); + auto te = *at::expand_size(t, {3, 3}); + ASSERT_EQ(te[1][1].item(), 5); +} + +TEST(BasicTest, TestForBlobStridesResizeCPU) { + // Checks that for_blob can correctly create tensors with non-empty offset and resize them + std::array storage; + std::iota(storage.begin(), storage.end(), 1); + auto t = at::for_blob(storage.data(), {3,}).strides({1,}).storage_offset(3).options(c10::TensorOptions(kInt)).make_tensor(); + auto te = *at::expand_size(t, {3, 3}); + ASSERT_EQ(te[1][1].item(), 5); +} diff --git a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp index c390305e2051..15220e58e248 100644 --- a/aten/src/ATen/test/cpu_profiling_allocator_test.cpp +++ b/aten/src/ATen/test/cpu_profiling_allocator_test.cpp @@ -199,7 +199,7 @@ int main(int argc, char* argv[]) { #ifdef C10_MOBILE // Need to disable mkldnn for this test since it allocated memory - // via raw_allocate inteface which requires context pointer and raw + // via raw_allocate interface which requires context pointer and raw // pointer to be the same. Tis is not true for mobile allocator. at::globalContext().setUserEnabledMkldnn(false); #endif diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp index 900758233432..9e594196c692 100644 --- a/aten/src/ATen/test/half_test.cpp +++ b/aten/src/ATen/test/half_test.cpp @@ -25,7 +25,7 @@ TEST(TestHalf, Arithmetic) { ASSERT_EQ(one + one, 2); } -TEST(TestHalf, Comparisions) { +TEST(TestHalf, Comparisons) { Half zero = 0; Half one = 1; ASSERT_LT(zero, one); diff --git a/aten/src/ATen/test/undefined_tensor_test.cpp b/aten/src/ATen/test/undefined_tensor_test.cpp index 91777c3a05c7..ec6997fae9b0 100644 --- a/aten/src/ATen/test/undefined_tensor_test.cpp +++ b/aten/src/ATen/test/undefined_tensor_test.cpp @@ -9,7 +9,7 @@ using namespace at; TEST(TestUndefined, UndefinedTest) { manual_seed(123); - // mainly test ops on undefined tensors don't segfault and give a reasonable errror message. + // mainly test ops on undefined tensors don't segfault and give a reasonable error message. Tensor und; Tensor ft = ones({1}, CPU(kFloat)); diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index a9b5a70f1de9..b7b756f74ba1 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -5,7 +5,7 @@ namespace { template class Memory : public ::testing::Test {}; template - class Arithmetics : public ::testing::Test {}; + class Arithmetic : public ::testing::Test {}; template class Comparison : public ::testing::Test {}; template @@ -92,7 +92,7 @@ namespace { using ComplexTypes = ::testing::Types; using ReducedFloatTestedTypes = ::testing::Types; TYPED_TEST_SUITE(Memory, ALLTestedTypes); - TYPED_TEST_SUITE(Arithmetics, FloatIntTestedTypes); + TYPED_TEST_SUITE(Arithmetic, FloatIntTestedTypes); TYPED_TEST_SUITE(Comparison, RealFloatIntReducedFloatTestedTypes); TYPED_TEST_SUITE(Bitwise, FloatIntTestedTypes); TYPED_TEST_SUITE(MinMax, RealFloatIntTestedTypes); @@ -691,7 +691,7 @@ namespace { AssertVectorized(NAME_INFO(DeInterleave FirstHalf), std::get<0>(cc), vec::loadu(vals)).check(true); AssertVectorized(NAME_INFO(DeInterleave SecondHalf), std::get<1>(cc), vec::loadu(vals + vec::size())).check(true); } - TYPED_TEST(Arithmetics, Plus) { + TYPED_TEST(Arithmetic, Plus) { using vec = TypeParam; using VT = ValueType; test_binary( @@ -703,7 +703,7 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_add_overflow)); } - TYPED_TEST(Arithmetics, Minus) { + TYPED_TEST(Arithmetic, Minus) { using vec = TypeParam; using VT = ValueType; test_binary( @@ -715,7 +715,7 @@ namespace { createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_sub_overflow)); } - TYPED_TEST(Arithmetics, Multiplication) { + TYPED_TEST(Arithmetic, Multiplication) { using vec = TypeParam; test_binary( NAME_INFO(mult), @@ -724,7 +724,7 @@ namespace { createDefaultBinaryTestCase(TestSeed(), false, true), RESOLVE_OVERLOAD(filter_mult_overflow)); } - TYPED_TEST(Arithmetics, Division) { + TYPED_TEST(Arithmetic, Division) { using vec = TypeParam; TestSeed seed; test_binary( diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7062a3048df..f7206cc34097 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -531,7 +531,7 @@ template std::enable_if_t::value, void> filter_div_ub(T& val1, T& val2) { //missing - //at least consdier zero division + //at least consider zero division auto ret = std::abs(val2); if (ret == 0) { val2 = T(1, 2); @@ -1291,7 +1291,7 @@ std::enable_if_t>::value, Complex> local_multiply(Compl T y_real = y.real(); T y_imag = y.imag(); #if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR) - //check multiplication considerin swap and fma + //check multiplication considering swap and fma T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; @@ -1362,7 +1362,7 @@ std::enable_if_t>::value, Complex> local_division(Compl return Complex(rr, ii); #else /* defined(CPU_CAPABILITY_ZVECTOR) */ #if defined(CPU_CAPABILITY_VSX) - //check multiplication considerin swap and fma + //check multiplication considering swap and fma T rr = x_real * y_real; T ii = x_imag * y_real; T neg_imag = -y_imag; diff --git a/aten/src/ATen/test/verify_api_visibility.cpp b/aten/src/ATen/test/verify_api_visibility.cpp index 5878ed352e5b..c6d2fcc6fb86 100644 --- a/aten/src/ATen/test/verify_api_visibility.cpp +++ b/aten/src/ATen/test/verify_api_visibility.cpp @@ -20,4 +20,8 @@ #error "CAFFE2_STATIC_LINK_CUDA should not be visible in public headers" #endif -auto main() -> int {} +#include + +TEST(VerifyApiVisibility, Test) { + ASSERT_EQ(1, 1); +} diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp index 1b4750b6c41e..263918af2662 100644 --- a/aten/src/ATen/test/vulkan_api_test.cpp +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -1232,7 +1232,7 @@ void test_matmul( } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { - // This will call at::bmm. Will crash for unknow reason. + // This will call at::bmm. Will crash for unknown reason. const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -1241,7 +1241,7 @@ TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_vulkan) { } TEST_F(VulkanAPITest, DISABLED_matmul_3d_weight_cpu) { - // This will call at::bmm. Will crash for unknow reason. + // This will call at::bmm. Will crash for unknown reason. const auto m1_cpu = at::rand({13, 23, 45}, at::device(at::kCPU).dtype(at::kFloat)); const auto m2_cpu = @@ -2004,7 +2004,7 @@ TEST_F(VulkanAPITest, conv2d_pw_prepack_bc_medium) { 1); // groups } -// The followin 2 tests failed on Meta's CI when all tests are executed. Output +// The following 2 tests failed on Meta's CI when all tests are executed. Output // has lots of nan. Cause unknown. // When this test is run alone (with gtest_filter), it passes. // The test also passes with smaller planes, see "conv2d_pw_prepack_medium". @@ -5664,7 +5664,7 @@ TEST_F(VulkanAPITest, var_2d_unbiased) { test_var({3, 5}, {1}, true, true); test_var({3, 5}, {1}, true, false); - // inpu.dim() == dim_list.size(), only keepdim == true is supported + // input.dim() == dim_list.size(), only keepdim == true is supported test_var({3, 5}, {0, 1}, true, true); } @@ -5672,7 +5672,7 @@ TEST_F(VulkanAPITest, var_2d_biased) { test_var({3, 5}, {1}, false, true); test_var({3, 5}, {1}, false, false); - // inpu.dim() == dim_list.size(), only keepdim == true is supported + // input.dim() == dim_list.size(), only keepdim == true is supported test_var({3, 5}, {0, 1}, false, true); } @@ -7142,12 +7142,12 @@ TEST_F(VulkanAPITest, clone_success) { } TEST_F(VulkanAPITest, clone_invalidinputs_exceptions) { - // Act: Vulkan supports Preserve and Contiguous memory foramts + // Act: Vulkan supports Preserve and Contiguous memory formats EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast); }, ::std::exception); - // Act: Vulkan supports Preserve and Contiguous memory foramts + // Act: Vulkan supports Preserve and Contiguous memory formats EXPECT_THROW({ clone_test({2, 3, 5, 161}, c10::MemoryFormat::ChannelsLast3d); }, ::std::exception); diff --git a/aten/src/ATen/test/vulkan_quantized_api_test.cpp b/aten/src/ATen/test/vulkan_quantized_api_test.cpp index 650afceb887c..2829aed94def 100644 --- a/aten/src/ATen/test/vulkan_quantized_api_test.cpp +++ b/aten/src/ATen/test/vulkan_quantized_api_test.cpp @@ -2116,7 +2116,7 @@ std::tuple produce_inputs_for_binary_op( input2_cpu = produce_random_tensor(input2_shape); if (compute_quantization_params) { - // compute appropiate scale and zero point for inputs + // compute appropriate scale and zero point for inputs const auto in1_quant_params = compute_quant_params(input1_cpu); in1_scale = std::get<0>(in1_quant_params); in1_zero_point = std::get<1>(in1_quant_params); @@ -2287,7 +2287,7 @@ void test_quantized_binary_op( apply_cpu_quantized_binary_op(op_name, input1_cpu_deq, input2_cpu_deq); if (compute_quantization_params || random_quantization_params) { - // compute appropiate scale and zero point for output + // compute appropriate scale and zero point for output const auto out_quant_params = compute_quant_params(output_cpu); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -2540,7 +2540,7 @@ void test_quantized_conv2d( bias_cpu = produce_random_tensor(bias_shape, 1.26, 5.97, 0.59); if (compute_quantization_params) { - // compute appropiate scale and zero point for input, weight and bias + // compute appropriate scale and zero point for input, weight and bias const auto in_quant_params = compute_quant_params(input_cpu, in_dtype); in_scale = std::get<0>(in_quant_params); in_zero_point = std::get<1>(in_quant_params); @@ -2624,7 +2624,7 @@ void test_quantized_conv2d( groups); if (compute_quantization_params || random_quantization_params) { - // compute appropiate scale and zero point for output + // compute appropriate scale and zero point for output const auto out_quant_params = compute_quant_params(output_cpu, out_dtype); out_scale = std::get<0>(out_quant_params); out_zero_point = std::get<1>(out_quant_params); @@ -3524,7 +3524,7 @@ TEST_F(VulkanAPITest, linear_4d_large) { test_quantized_linear({9, 13, 11, 17}, {23, 17}, {23}); } -// The following code is not directly releated to quantization. We put it here +// The following code is not directly related to quantization. We put it here // since we are not able to run this test on GH's CI: for some unknown reason, // we are not able to reference symbols in the vulkan directory, hence the build // on GH fails. Moving the test here so we are still able to run it on @@ -3566,7 +3566,7 @@ TEST_F(VulkanAPITest, extract_texel_test) { // is the channel count. // We always start a new batch on a new z. Hence, when c cannot be divided by // 4, there are some undefined values in the padding area. We use -1 to - // indicate that we are not performing comparsion on those values. + // indicate that we are not performing comparison on those values. std::tuple test_cases[]{ {{0, 0, 0}, {0, hw, 2 * hw, 3 * hw}}, {{1, 0, 0}, {1, hw + 1, 2 * hw + 1, 3 * hw + 1}}, @@ -3672,7 +3672,7 @@ TEST_F(VulkanAPITest, channel_to_width_packing_test) { at::Tensor output = at::native::vulkan::ops::convert(v_output); // This tensor will be width-packed. Meaning that each texel represent - // consecutive elements along the width dimension. The differece between + // consecutive elements along the width dimension. The difference between // consecutive texels is 1. std::tuple test_cases[]{ {{0, 0, 0}, {0, 1, 2, 3}}, diff --git a/aten/src/ATen/xpu/XPUEvent.h b/aten/src/ATen/xpu/XPUEvent.h index ededd6ebf4f1..19d42aae080f 100644 --- a/aten/src/ATen/xpu/XPUEvent.h +++ b/aten/src/ATen/xpu/XPUEvent.h @@ -12,7 +12,7 @@ namespace at::xpu { * must match the same device. * * Currently, XPUEvent does NOT support to export an inter-process event from - * another process via inter-process comunication(IPC). So it means that + * another process via inter-process communication(IPC). So it means that * inter-process communication for event handles between different processes is * not available. This could impact some applications that rely on cross-process * synchronization and communication. diff --git a/aten/src/README.md b/aten/src/README.md index 3127ed5c8c39..fa279c89d26c 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -8,7 +8,7 @@ multiple variants of the library, summarized here: * THC = TorcH Cuda * THCS = TorcH Cuda Sparse (now defunct) * THNN = TorcH Neural Network (now defunct) -* THS = TorcH Sparse (now defunct) +* THS = TorcH Sparse (now defunct) (You'll also see these abbreviations show up in symbol names.) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 1088634ce911..900a93c552b4 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -22,7 +22,7 @@ import time import weakref from contextlib import contextmanager -from typing import Any, NamedTuple, TYPE_CHECKING +from typing import Any, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar from unittest.mock import MagicMock import numpy as np @@ -54,6 +54,7 @@ from torch._inductor.utils import fresh_cache except ImportError: from _dynamo.utils import clone_inputs, graph_break_reasons + from _inductor.utils import fresh_cache import torch._functorch.config from torch._functorch.aot_autograd import set_model_name @@ -75,7 +76,10 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Sequence + +_D = TypeVar("_D", bound=dict[str, Any]) +_T = TypeVar("_T") log = logging.getLogger(__name__) @@ -766,7 +770,17 @@ def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor: return (time_total, result) if return_result else time_total -def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]: +@overload +def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ... + + +@overload +def _normalize_bench_inputs( + example_inputs: Sequence[_T], +) -> tuple[tuple[_T, ...], dict[str, Any]]: ... + + +def _normalize_bench_inputs(example_inputs): # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary, # and consumed like `model(**example_inputs)`. # For other benchmarks, example_inputs are formatted as tuple and consumed @@ -1671,7 +1685,7 @@ def __init__(self): self.grad_scaler = DummyGradScaler() self.autocast = contextlib.nullcontext self.autocast_arg = {} - self.optimizer = None + self.optimizer: Optional[torch.optim.Optimizer] = None self._args = None def setup_amp(self, current_device=None): diff --git a/benchmarks/dynamo/genai_layers/README.md b/benchmarks/dynamo/genai_layers/README.md new file mode 100644 index 000000000000..d2a11e0acc21 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/README.md @@ -0,0 +1,23 @@ +# GenAI Kernel Benchmark + +This directory contains benchmarks for the GenAI kernels. It compares pytorch eager, pytorch compiler, quack, and liger. + + +## Setup + +Assuming pytorch is installed. + +``` +pip install -r requirements.txt +``` + +## Run + +``` + python benchmark.py --list # List all available benchmarks + python benchmark.py --all # Run all benchmarks + python benchmark.py cross_entropy_forward # Run specific benchmark + python benchmark.py softmax_forward softmax_backward # Run multiple benchmarks +``` + +Add `--visualize` to plot graph for the benchmark results. diff --git a/benchmarks/dynamo/genai_layers/benchmark.py b/benchmarks/dynamo/genai_layers/benchmark.py new file mode 100644 index 000000000000..70349ee44409 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/benchmark.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Benchmark runner for various kernel implementations. + +This script provides a command-line interface to run benchmarks for different +kernel implementations including CrossEntropy, Softmax, RMSNorm, and LayerNorm +kernels in both forward and backward directions. +""" + +import argparse +import sys + +from kernels import ( + BenchmarkKernel, + CrossEntropyBackward, + CrossEntropyForward, + LayerNormBackward, + LayerNormForward, + RMSNormBackward, + RMSNormForward, + SoftmaxBackward, + SoftmaxForward, +) + +import torch + + +torch._dynamo.config.automatic_dynamic_shapes = False +# Needed since changing args to function causes recompiles +torch._dynamo.config.recompile_limit = 1000000 + + +# Registry of all available benchmarks +BENCHMARK_REGISTRY: dict[str, type[BenchmarkKernel]] = { + "cross_entropy_forward": CrossEntropyForward, + "cross_entropy_backward": CrossEntropyBackward, + "softmax_forward": SoftmaxForward, + "softmax_backward": SoftmaxBackward, + "rmsnorm_forward": RMSNormForward, + "rmsnorm_backward": RMSNormBackward, + "layernorm_forward": LayerNormForward, + "layernorm_backward": LayerNormBackward, +} + + +def show_environment_info(): + """Show environment information.""" + print("Environment information:") + print(f" Python version: {sys.version}") + print(f" PyTorch version: {torch.__version__}") + print(f" CUDA version: {torch.version.cuda}") + + +def list_benchmarks(): + """List all available benchmarks.""" + print(f"Available benchmarks: {list(BENCHMARK_REGISTRY.keys())}") + + +def run_benchmark( + benchmark_name: str, + should_visualize: bool = False, + compile_mode: str = "max-autotune-no-cudagraphs", +): + """Run a specific benchmark.""" + if benchmark_name not in BENCHMARK_REGISTRY: + print(f"Error: Unknown benchmark '{benchmark_name}'") + print("Use --list to see available benchmarks") + return False + + print(f"Running benchmark: {benchmark_name}") + print(f"Torch compile mode: {compile_mode}") + print("=" * 60) + + benchmark_class = BENCHMARK_REGISTRY[benchmark_name] + benchmark = benchmark_class(compile_mode) + benchmark.benchmark() + if should_visualize: + benchmark.visualize() + + return True + + +def run_all_benchmarks(should_visualize: bool = False, compile_mode: str = "default"): + """Run all available benchmarks.""" + print("Running all benchmarks...") + print(f"Torch compile mode: {compile_mode}") + print("=" * 60) + + for name, cls in BENCHMARK_REGISTRY.items(): + print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") + benchmark = cls(compile_mode) + benchmark.benchmark() + if should_visualize: + benchmark.visualize() + print() + + +def main(): + show_environment_info() + + parser = argparse.ArgumentParser( + description="Benchmark runner for kernel implementations", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python benchmark.py --list # List all available benchmarks + python benchmark.py --all # Run all benchmarks + python benchmark.py cross_entropy_forward # Run specific benchmark + python benchmark.py softmax_forward softmax_backward # Run multiple benchmarks + """, + ) + + parser.add_argument( + "benchmarks", + nargs="*", + help="Names of benchmarks to run (use --list to see available options)", + ) + + parser.add_argument( + "--list", action="store_true", help="List all available benchmarks" + ) + + parser.add_argument( + "--all", action="store_true", help="Run all available benchmarks" + ) + + parser.add_argument( + "--visualize", + action="store_true", + help="Visualize results after running benchmarks", + ) + + parser.add_argument( + "--compile-mode", + choices=["default", "max-autotune-no-cudagraphs"], + default="max-autotune-no-cudagraphs", + help="Torch compile mode to use (default: default)", + ) + + args = parser.parse_args() + + # Handle list option + if args.list: + list_benchmarks() + return + + # Handle all option + if args.all: + run_all_benchmarks(args.visualize, args.compile_mode) + return + + # Handle specific benchmarks + if not args.benchmarks: + print("Error: No benchmarks specified") + print("Use --list to see available benchmarks or --all to run all benchmarks") + parser.print_help() + sys.exit(1) + + for benchmark_name in args.benchmarks: + run_benchmark(benchmark_name, args.visualize, args.compile_mode) + print() # Add spacing between benchmarks + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/genai_layers/kernels.py b/benchmarks/dynamo/genai_layers/kernels.py new file mode 100644 index 000000000000..ee79f02761ed --- /dev/null +++ b/benchmarks/dynamo/genai_layers/kernels.py @@ -0,0 +1,639 @@ +from typing import Any + +import cutlass +import cutlass.torch as cutlass_torch +from utils import BenchmarkKernel + +import torch +import torch.nn.functional as F + + +class CrossEntropyForward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Read x (M*N elements) + read target (M elements) + write loss (M elements) + x, target = args + M, N = x.shape + dtype = x.dtype + return (M * N + M + M) * dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + return lambda: F.cross_entropy(x, target, reduction="none") + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(target, 0) + + # Need `lambda` otherwise torch.compile will not trace the function. + # More discussion: https://github.com/pytorch/pytorch/issues/158455 + compiled_cross_entropy = torch.compile( + lambda x, target: F.cross_entropy(x, target, reduction="none"), + mode=self.compile_mode, + fullgraph=True, + ) + return lambda: compiled_cross_entropy(x, target) + + def quack(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target = args + from quack.cross_entropy import _cross_entropy + + return lambda: _cross_entropy(x, target) + + def liger(self, args, kwargs=None) -> Any: + assert kwargs is None + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + + x, target = args + cross_entropy = LigerCrossEntropyLoss(reduction="none") + return lambda: cross_entropy(x, target) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"\n Tensor dimensions: [{M}, {N}]") + # quack requires cutlass dtype + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype) + target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64) + self.benchmark_single_shape((x, target), setting=f"shape: [{M}, {N}]") + + def check_accuracy(self, args, kwargs) -> None: + res = {} + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() + gold = res["eager"] + for backend in self.available_backends: + if backend == "eager": + continue + if backend == "quack": + # quack's cross_entropy only returns float32 loss output. + # Need to convert it to the same dtype as gold for comparison. + res[backend] = res[backend].to(gold.dtype) + try: + torch.testing.assert_close(res[backend], gold) + print( + f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel" + ) + except Exception as e: + print( + f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" + ) + + +class CrossEntropyBackward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Read x (M*N elements) + read target (M elements) + read dloss (M elements) + write grad(M*N elements) + x, target, dloss = args + # Memory ba + M, N = x.shape + return ( + 2 * M * N * x.dtype.itemsize + + M * target.dtype.itemsize + + M * dloss.dtype.itemsize + ) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target, dloss = args + loss = F.cross_entropy(x, target, reduction="none") + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, target, dloss = args + + compiled_cross_entropy = torch.compile( + lambda x, target: F.cross_entropy(x, target, reduction="none"), + mode=self.compile_mode, + fullgraph=True, + ) + loss = compiled_cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def quack(self, args, kwargs=None) -> Any: + from quack.cross_entropy import cross_entropy + + assert kwargs is None + x, target, dloss = args + loss = cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def liger(self, args, kwargs=None) -> Any: + assert kwargs is None + from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss + + x, target, dloss = args + cross_entropy = LigerCrossEntropyLoss(reduction="none") + loss = cross_entropy(x, target) + return lambda: torch.autograd.grad( + loss, x, grad_outputs=dloss, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn( + M, N, device="cuda", dtype=torch_dtype, requires_grad=True + ) + target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64) + dloss = torch.randn(M, device="cuda", dtype=torch.float32) + self.benchmark_single_shape( + (x, target, dloss), setting=f"shape: [{M}, {N}]" + ) + + +class SoftmaxForward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + (x,) = args + M, N = x.shape + return 2 * M * N * x.dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + (x,) = args + return lambda: F.softmax(x, dim=-1) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + (x,) = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_softmax = torch.compile( + lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True + ) + return lambda: compiled_softmax(x) + + def quack(self, args, kwargs=None) -> Any: + from quack.softmax import softmax + + assert kwargs is None + (x,) = args + return lambda: softmax(x) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.softmax import LigerSoftmax + + assert kwargs is None + (x,) = args + softmax = LigerSoftmax().to("cuda") + return lambda: softmax(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x,), setting=f"shape: [{M}, {N}]") + + +class SoftmaxBackward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + # Memory: read dy and y, write ax backward + x, dy = args + M, N = x.shape + return 3 * M * N * x.dtype.itemsize + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, dy = args + y = F.softmax(x, dim=-1) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, dy = args + compiled_softmax = torch.compile( + lambda x: F.softmax(x, dim=-1), mode=self.compile_mode, fullgraph=True + ) + y = compiled_softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def quack(self, args, kwargs=None) -> Any: + from quack.softmax import softmax + + assert kwargs is None + x, dy = args + + y = softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.softmax import LigerSoftmax + + assert kwargs is None + x, dy = args + softmax = LigerSoftmax().to("cuda") + y = softmax(x) + return lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = 0.1 * torch.randn( + M, N, device="cuda", dtype=torch_dtype, requires_grad=True + ) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, dy), setting=f"shape: [{M}, {N}]") + + +class RMSNormForward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + (16384, 131072), + (8192, 262144), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w = args + M, N = x.shape + return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize + + def rms_norm_ref(self, x, w): + x_f32 = x.float() + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * w + ).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + return lambda: self.rms_norm_ref(x, w) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_rms_norm = torch.compile( + self.rms_norm_ref, mode=self.compile_mode, fullgraph=True + ) + return lambda: compiled_rms_norm(x, w) + + def quack(self, args, kwargs=None) -> Any: + # Note: only supper weight with float32 dtype + from quack.rmsnorm import _rmsnorm_fwd + + x, w = args + return lambda: _rmsnorm_fwd(x, w, eps=1e-6) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.rms_norm import LigerRMSNorm + + x, w = args + M, N = x.shape + liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda() + liger_rmsnorm.weight.data.copy_(w) + return lambda: liger_rmsnorm(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + w = torch.randn(N, device="cuda", dtype=torch.float32) + self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]") + + +class RMSNormBackward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # TODO: OOM for (32768, 65536) on h100 + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w, dy = args + # x, dy: [M, N], w: [N] + M, N = x.shape + # Read x, w, dy, write dx, dw + return 3 * M * N * x.dtype.itemsize + 2 * N * w.dtype.itemsize + + def rms_norm_ref(self, x, w): + x_f32 = x.float() + return ( + x_f32 + * torch.rsqrt(torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) + * w + ).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = self.rms_norm_ref(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = torch.compile(self.rms_norm_ref, mode=self.compile_mode, fullgraph=True)( + x, w + ) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def quack(self, args, kwargs=None) -> Any: + from quack.rmsnorm import _rmsnorm_backward + + ( + x, + w, + dy, + ) = args + M, N = x.shape + rstd = torch.randn(M, device="cuda", dtype=torch.float32) + return lambda: _rmsnorm_backward(x, w, dy, rstd) + + def liger(self, args, kwargs=None) -> Any: + from liger_kernel.transformers.rms_norm import LigerRMSNorm + + x, w, dy = args + M, N = x.shape + liger_rmsnorm = LigerRMSNorm(hidden_size=N, eps=1e-6).cuda() + liger_rmsnorm.weight.data.copy_(w) + y = liger_rmsnorm(x) + return lambda: torch.autograd.grad( + y, [x, liger_rmsnorm.weight], grad_outputs=dy, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True) + w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]") + + +class LayerNormForward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "quack", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # OOM for (16384, 131072) on h100 + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w = args + M, N = x.shape + # Read x ([M, N]), w ([N]), write y ([M, N]) + return 2 * M * N * x.dtype.itemsize + N * w.dtype.itemsize + + def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6): + x_f32 = x.float() + return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + return lambda: self.layernorm_ref(x, w) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w = args + + # Mark batch size as dynamic for realistic workload + torch._dynamo.mark_dynamic(x, 0) + + compiled_layernorm = torch.compile( + self.layernorm_ref, mode=self.compile_mode, fullgraph=True + ) + return lambda: compiled_layernorm(x, w, eps=1e-6) + + def quack(self, args, kwargs) -> Any: + # Note: quack layernorm does not support bias + from quack.layernorm import layernorm + + x, w = args + return lambda: layernorm(x, w, eps=1e-6) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.layer_norm import LigerLayerNorm + + x, w = args + M, N = x.shape + liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda() + liger_layernorm.weight.data.copy_(w) + liger_layernorm.bias.data.copy_( + torch.zeros(N, device="cuda", dtype=torch.float32) + ) + return lambda: liger_layernorm(x) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + w = torch.randn(N, device="cuda", dtype=torch.float32) + self.benchmark_single_shape((x, w), setting=f"shape: [{M}, {N}]") + + +class LayerNormBackward(BenchmarkKernel): + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + super().__init__(compile_mode) + self.available_backends = ["eager", "compiled", "liger"] + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # OOM for (16384, 131072), (8192, 262144) + return ( + (32768, 256), + (32768, 512), + (32768, 1024), + (32768, 2048), + (32768, 4096), + (32768, 8192), + (32768, 16384), + (32768, 32768), + (32768, 65536), + ) + + def get_memory_bytes(self, args, kwargs) -> int: + x, w, dy = args + M, N = x.shape + # Read x ([M, N]), w ([N]), dy ([M, N]), write dx ([M, N]), dw ([N]) + return ( + 2 * M * N * x.dtype.itemsize + + 2 * N * w.dtype.itemsize + + M * N * dy.dtype.itemsize + ) + + def layernorm_ref(self, x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6): + x_f32 = x.float() + return F.layer_norm(x_f32, w.shape, w, None, eps).to(x.dtype) + + def eager(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + y = self.layernorm_ref(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def compiled(self, args, kwargs=None) -> Any: + assert kwargs is None + x, w, dy = args + compiled_layernorm = torch.compile( + self.layernorm_ref, mode=self.compile_mode, fullgraph=True + ) + y = compiled_layernorm(x, w) + return lambda: torch.autograd.grad( + y, [x, w], grad_outputs=dy, retain_graph=True + ) + + def liger(self, args, kwargs) -> Any: + from liger_kernel.transformers.layer_norm import LigerLayerNorm + + x, w, dy = args + M, N = x.shape + liger_layernorm = LigerLayerNorm(hidden_size=N, eps=1e-6).cuda() + liger_layernorm.weight.data.copy_(w) + liger_layernorm.bias.data.copy_( + torch.zeros(N, device="cuda", dtype=torch.float32) + ) + y = liger_layernorm(x) + return lambda: torch.autograd.grad( + y, [x, liger_layernorm.weight], grad_outputs=dy, retain_graph=True + ) + + def benchmark(self): + for M, N in self.get_shapes(): + print(f"Tensor dimensions: [{M}, {N}]") + torch_dtype = cutlass_torch.dtype(cutlass.BFloat16) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype, requires_grad=True) + w = torch.randn(N, device="cuda", dtype=torch.float32, requires_grad=True) + dy = torch.randn(M, N, device="cuda", dtype=torch_dtype) + self.benchmark_single_shape((x, w, dy), setting=f"shape: [{M}, {N}]") diff --git a/benchmarks/dynamo/genai_layers/requirements.txt b/benchmarks/dynamo/genai_layers/requirements.txt new file mode 100644 index 000000000000..ddd1f0101349 --- /dev/null +++ b/benchmarks/dynamo/genai_layers/requirements.txt @@ -0,0 +1,4 @@ +quack-kernels +liger-kernel +nvidia-cutlass-dsl==4.1.0.dev0 +matplotlib diff --git a/benchmarks/dynamo/genai_layers/utils.py b/benchmarks/dynamo/genai_layers/utils.py new file mode 100644 index 000000000000..e11995ee0b5f --- /dev/null +++ b/benchmarks/dynamo/genai_layers/utils.py @@ -0,0 +1,243 @@ +import os +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import matplotlib.pyplot as plt + +import torch +from torch._inductor.runtime.benchmarking import benchmarker + + +def benchmark_kernel_in_milliseconds(func: Callable, *args, **kwargs) -> float: + # warmup + for _ in range(5): + func(*args, **kwargs) + with torch.compiler.set_stance("fail_on_recompile"): + return benchmarker.benchmark_gpu(lambda: func(*args, **kwargs)) + + +@dataclass +class Performance: + # Benchmark setting usually the shape of the input tensor + setting: str + + # Latency in milliseconds + latency: float + + # Number of memory access in bytes + memory_bytes: float + + # Memory bandwidth in GB/s + memory_bandwidth: float = 0.0 + + # Compute intensity in FLOPs/byte + compute_intensity: float = 0.0 + + def __post_init__(self): + self.memory_bandwidth = self.memory_bytes / (self.latency / 1000) / 1e9 + + def __str__(self): + return f"setting: {self.setting}, latency: {self.latency} ms, memory bandwidth: {self.memory_bandwidth} GB/s" + + +class BenchmarkKernel: + def __init__(self, compile_mode: str = "max-autotune-no-cudagraphs"): + self.name = self.__class__.__name__ + self.available_backends: list[str] = [] + self.compile_mode: str = compile_mode + + # mapping from backend to list of performance results + self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list) + + def get_memory_bytes(self, args, kwargs) -> int: + # Get the necessary memory access in bytes for the kernelßß + raise NotImplementedError + + def get_shapes(self) -> tuple[tuple[int, ...], ...]: + # Get a list of input shapes to benchmark the kernel + raise NotImplementedError + + def eager(self, args, kwargs) -> Any: + raise NotImplementedError + + def compiled(self, args, kwargs) -> Any: + raise NotImplementedError + + def helion(self, args, kwargs) -> Any: + raise NotImplementedError + + def quack(self, args, kwargs) -> Any: + raise NotImplementedError + + def liger(self, args, kwargs) -> Any: + raise NotImplementedError + + def triton(self, args, kwargs) -> Any: + raise NotImplementedError + + def benchmark(self): + raise NotImplementedError + + def clone_inputs(self, args, kwargs) -> Any: + args_ref = [ + arg.clone().detach().requires_grad_(arg.requires_grad) for arg in args + ] + + kwargs_ref = ( + { + k: ( + v.clone().detach().requires_grad_(v.requires_grad) + if isinstance(v, torch.Tensor) + else v + ) + for k, v in kwargs.items() + } + if kwargs + else kwargs + ) + + return args_ref, kwargs_ref + + def check_accuracy(self, args, kwargs) -> None: + res = {} + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() + gold = res["eager"] + for backend in self.available_backends: + if backend == "eager": + continue + try: + torch.testing.assert_close(res[backend], gold) + for t, gold_t in zip(res[backend], gold): + if t.requires_grad: + torch.testing.assert_close(t.grad, gold_t.grad) + print( + f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel" + ) + except Exception as e: + print( + f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}" + ) + + def benchmark_single_shape( + self, args, kwargs=None, should_check_accuracy=True, setting: str = "" + ): + for backend in self.available_backends: + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + try: + avg_time = benchmark_kernel_in_milliseconds( + getattr(self, backend)(args_ref, kwargs_ref) + ) + except Exception as e: + print( + f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}" + ) + self.available_backends.remove(backend) + continue + mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref) + perf = Performance(setting, avg_time, mem_bytes) + print(f"{self.name} kernel on {backend} backend. {perf}") + self.profiling_results[backend].append(perf) + + if should_check_accuracy: + self.check_accuracy(args, kwargs) + + def visualize(self) -> None: + visualize_comparison( + self.profiling_results, + title=f"{self.name}", + output_path=f"{self.name}_bench", + ) + return + + +def get_backend_colors() -> dict[str, str]: + """Get consistent color scheme for different backends.""" + return { + "eager": "#1f77b4", # blue + "compiled": "#ff7f0e", # orange + "quack": "#2ca02c", # green + "liger": "#d62728", # red + "helion": "#9467bd", # purple + "triton": "#8c564b", # brown + "cutlass": "#e377c2", # pink + "flash_attn": "#7f7f7f", # gray + "default": "#000000", # black + } + + +def visualize_comparison( + profiling_results: dict[str, list[Performance]], + title: Optional[str] = None, + output_path: Optional[str] = None, +) -> None: + """ + Create a single memory_bandwidth comparison plot from profiling results. + + Args: + profiling_results: Dict mapping backend names to lists of Performance objects + output_path: Path to save the plot (optional) + """ + # Get backend colors + backend_colors = get_backend_colors() + + # Extract settings from eager backend which runs all settings + all_settings = [] + for perf in profiling_results["eager"]: + all_settings.append(perf.setting) + + # Create single plot + fig, ax = plt.subplots(1, 1, figsize=(12, 8)) + + for backend in profiling_results: + backend_perfs = profiling_results[backend] + perf_dict = {perf.setting: perf for perf in backend_perfs} + + x_vals = [] + y_vals = [] + for i, setting in enumerate(all_settings): + if setting in perf_dict: + x_vals.append(i) + y_vals.append(perf_dict[setting].memory_bandwidth) + + if x_vals: # Only plot if we have data + color = backend_colors.get(backend, backend_colors["default"]) + ax.plot( + x_vals, + y_vals, + "o-", + label=backend, + color=color, + linewidth=2, + markersize=8, + alpha=0.8, + ) + + # Configure the plot + ax.set_title(title or "Memory Bandwidth Comparison", fontsize=16) + ax.set_xlabel("Shape", fontsize=12) + ax.set_ylabel("memory bandwidth (GB/s)", fontsize=12) + ax.set_xticks(range(len(all_settings))) + ax.set_xticklabels( + [ + s.replace("shape: ", "").replace("[", "").replace("]", "") + for s in all_settings + ], + rotation=45, + ha="right", + ) + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save the plot if output path is provided + if output_path: + # Save as PNG + os.makedirs("pics", exist_ok=True) + full_path = os.path.join("pics", output_path + ".png") + plt.savefig(full_path, dpi=300, bbox_inches="tight", facecolor="white") + + plt.close() diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index edc9d0f73d16..c0d676f88510 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,89 +1,89 @@ -add_loop_eager,compile_time_instruction_count,3017000000,0.015 +add_loop_eager,compile_time_instruction_count,3070000000,0.10 -add_loop_eager_dynamic,compile_time_instruction_count,4352000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,4432000000,0.10 -add_loop_inductor,compile_time_instruction_count,29490000000,0.015 +add_loop_inductor,compile_time_instruction_count,30280000000,0.10 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38760000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,39910000000,0.10 -add_loop_inductor_gpu,compile_time_instruction_count,26000000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.10 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,947600000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.10 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18490000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18030000000,0.10 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.10 -basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10297683772,0.2 +basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000000,0.2 -update_hint_regression,compile_time_instruction_count,1673000000,0.02 +update_hint_regression,compile_time_instruction_count,1719000000,0.10 -sum_floordiv_regression,compile_time_instruction_count,986800000,0.015 +sum_floordiv_regression,compile_time_instruction_count,966100000,0.10 -symint_sum,compile_time_instruction_count,3166000000,0.015 +symint_sum,compile_time_instruction_count,3237000000,0.10 -symint_sum_loop,compile_time_instruction_count,4202000000,0.015 +symint_sum_loop,compile_time_instruction_count,4299000000,0.10 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2103000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2151000000,0.10 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6004000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,6124000000,0.10 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8783000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,9005000000,0.10 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1940000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1989000000,0.10 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3885000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3959000000,0.10 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10470000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0.10 -mm_loop_inductor_gpu,compile_time_instruction_count,4324000000,0.015 +mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.10 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8116000000,0.015 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.10 -basic_NestedModule_eager,compile_time_instruction_count,8152524390,0.015 +basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.10 -basic_InlineMod_eager,compile_time_instruction_count,7255000000,0.015 +basic_InlineMod_eager,compile_time_instruction_count,7464000000,0.10 diff --git a/build_variables.bzl b/build_variables.bzl index 99290d5318cd..f6fba33dc4d4 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -625,6 +625,10 @@ libtorch_nativert_sources = [ "torch/nativert/executor/memory/AliasAnalyzer.cpp", "torch/nativert/executor/memory/LayoutPlanner.cpp", "torch/nativert/executor/memory/LayoutManager.cpp", + "torch/nativert/kernels/KernelRegistry.cpp", + "torch/nativert/kernels/NativeKernels.cpp", + "torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp", + "torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp", ] torch_mobile_tracer_sources = [ @@ -734,6 +738,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/UCCTracing.cpp", "torch/csrc/distributed/c10d/UCCUtils.cpp", "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", + "torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp", "torch/csrc/distributed/c10d/cuda/utils.cpp", "torch/csrc/distributed/c10d/cuda/StreamBlock.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", diff --git a/c10/BUCK.oss b/c10/BUCK.oss index 4b2cbd049b85..4ec4ab5beabb 100644 --- a/c10/BUCK.oss +++ b/c10/BUCK.oss @@ -37,8 +37,6 @@ cxx_library( ), exported_linker_flags = [], exported_preprocessor_flags = [ - '-DC10_USING_CUSTOM_GENERATED_MACROS', - '-DC10_USE_GLOG', '-DC10_USE_MINIMAL_GLOG', '-DC10_MOBILE', '-fexceptions', diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt index 8e9d267352dd..f82e460cafc3 100644 --- a/c10/CMakeLists.txt +++ b/c10/CMakeLists.txt @@ -18,16 +18,12 @@ else() set(C10_LIB c10) endif() - # ---[ Configure macro file. - set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in - set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in - set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in - set(C10_USE_NUMA ${USE_NUMA}) - set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) - set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) - configure_file( - ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in - ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h) +set(C10_USE_GFLAGS ${USE_GFLAGS}) # also used in torch/headeronly +set(C10_USE_GLOG ${USE_GLOG}) # also used in torch/headeronly +set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # also used in torch/headeronly +set(C10_USE_NUMA ${USE_NUMA}) # also used in torch/headeronly +set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) # also used in torch/headeronly +set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) # also used in torch/headeronly # Note: if you want to add ANY dependency to the c10 library, make sure you # check with the core PyTorch developers as the dependency will be @@ -94,6 +90,8 @@ if(NOT BUILD_LIBTORCHLESS) if(C10_USE_GLOG) target_link_libraries(c10 PUBLIC glog::glog) endif() + + target_link_libraries(c10 PUBLIC headeronly) target_link_libraries(c10 PRIVATE fmt::fmt-header-only) target_link_libraries(c10 PRIVATE nlohmann) target_link_libraries(c10 PRIVATE moodycamel) @@ -170,8 +168,6 @@ endif() install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR} DESTINATION include FILES_MATCHING PATTERN "*.h") -install(FILES ${CMAKE_BINARY_DIR}/c10/macros/cmake_macros.h - DESTINATION include/c10/macros) if(MSVC AND C10_BUILD_SHARED_LIBS) install(FILES $ DESTINATION lib OPTIONAL) diff --git a/c10/core/AllocatorConfig.cpp b/c10/core/AllocatorConfig.cpp deleted file mode 100644 index f0a3134fae68..000000000000 --- a/c10/core/AllocatorConfig.cpp +++ /dev/null @@ -1,242 +0,0 @@ -#include -#include -#include -#include - -namespace c10::CachingAllocator { - -namespace { -constexpr size_t kRoundUpPowerOfTwoIntervals = 16; -constexpr size_t kMB = 1024 * 1024ul; -constexpr size_t kRoundUpPowerOfTwoStart = 1 * kMB; // 1MB -constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB -} // anonymous namespace - -AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() { - static AcceleratorAllocatorConfig instance; -#define C10_ALLOCATOR_CONFIG_PARSE_ENV(env, deprecated) \ - auto env##_name = c10::utils::get_env(#env); \ - if (env##_name.has_value()) { \ - if (deprecated) { \ - TORCH_WARN_ONCE(#env " is deprecated, use PYTORCH_ALLOC_CONF instead"); \ - } \ - instance.parseArgs(env##_name.value()); \ - return true; \ - } - static bool env_flag [[maybe_unused]] = []() { - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_ALLOC_CONF, false) - // Keep this for backwards compatibility - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_CUDA_ALLOC_CONF, /*deprecated=*/true) - C10_ALLOCATOR_CONFIG_PARSE_ENV(PYTORCH_HIP_ALLOC_CONF, /*deprecated=*/true) - return false; - }(); -#undef C10_ALLOCATOR_CONFIG_PARSE_ENV - return instance; -} - -AcceleratorAllocatorConfig::AcceleratorAllocatorConfig() { - roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0); -} - -size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) { - size_t log_size = (63 - llvm::countLeadingZeros(size)); - - // Our intervals start at 1MB and end at 64GB - const size_t interval_start = - 63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart); - const size_t interval_end = - 63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd); - TORCH_CHECK( - interval_end - interval_start == kRoundUpPowerOfTwoIntervals, - "kRoundUpPowerOfTwoIntervals mismatch"); - - size_t index = - (log_size > interval_start) ? (log_size - interval_start) : 0ul; - index = std::min(index, kRoundUpPowerOfTwoIntervals - 1); - return instance().roundup_power2_divisions_[index]; -} - -size_t AcceleratorAllocatorConfig::parseMaxSplitSize( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB; - constexpr size_t max_allowed_split_size_mb = - std::numeric_limits::max() / kMB; - - size_t val_env = tokenizer.toSizeT(++i); - TORCH_CHECK( - val_env >= min_allowed_split_size_mb, - "CachingAllocator option max_split_size_mb too small, must be >= ", - min_allowed_split_size_mb); - val_env = std::min(val_env, max_allowed_split_size_mb); - max_split_size_ = val_env * kMB; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - constexpr size_t min_allowed_split_size_mb = kLargeBuffer / kMB; - constexpr size_t max_allowed_split_size_mb = - std::numeric_limits::max() / kMB; - - size_t val_env = tokenizer.toSizeT(++i); - TORCH_CHECK( - val_env >= min_allowed_split_size_mb, - "CachingAllocator option max_non_split_rounding_mb too small, must be >= ", - min_allowed_split_size_mb); - val_env = std::min(val_env, max_allowed_split_size_mb); - max_non_split_rounding_size_ = val_env * kMB; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - double val_env = tokenizer.toDouble(++i); - TORCH_CHECK( - val_env > 0 && val_env < 1.0, - "garbage_collect_threshold is invalid, set it in (0.0, 1.0)"); - garbage_collection_threshold_ = val_env; - - return i; -} - -size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - bool first_value = true; - - if (tokenizer[++i] == "[") { - size_t last_index = 0; - // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) - while (++i < tokenizer.size() && tokenizer[i] != "]") { - size_t value_index = i; - tokenizer.checkToken(++i, ":"); - size_t value = tokenizer.toSizeT(++i); - TORCH_CHECK( - value == 0 || llvm::isPowerOf2_64(value), - "For roundups, the divisions has to be power of 2 or 0 to disable roundup "); - - if (tokenizer[value_index] == ">") { - std::fill( - std::next( - roundup_power2_divisions_.begin(), - static_cast::difference_type>( - last_index + 1)), - roundup_power2_divisions_.end(), - value); - } else { - size_t boundary = tokenizer.toSizeT(value_index); - TORCH_CHECK( - llvm::isPowerOf2_64(boundary), - "For roundups, the intervals have to be power of 2 "); - - size_t index = 63 - llvm::countLeadingZeros(boundary); - index = - std::clamp(index, size_t{0}, roundup_power2_divisions_.size() - 1); - - if (first_value) { - std::fill( - roundup_power2_divisions_.begin(), - std::next( - roundup_power2_divisions_.begin(), - static_cast::difference_type>(index)), - value); - first_value = false; - } - roundup_power2_divisions_[index] = value; - last_index = index; - } - - if (tokenizer[i + 1] != "]") { - tokenizer.checkToken(++i, ","); - } - } - TORCH_INTERNAL_ASSERT( - i < tokenizer.size(), - "Expected closing bracket ']' in ConfigTokenizer but reached end of config"); - } else { // Keep this for backwards compatibility - size_t value = tokenizer.toSizeT(i); - TORCH_CHECK( - llvm::isPowerOf2_64(value), - "For roundups, the divisions has to be power of 2 "); - std::fill( - roundup_power2_divisions_.begin(), - roundup_power2_divisions_.end(), - value); - } - return i; -} - -size_t AcceleratorAllocatorConfig::parseExpandableSegments( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - use_expandable_segments_ = tokenizer.toBool(++i); - - return i; -} - -size_t AcceleratorAllocatorConfig::parsePinnedUseBackgroundThreads( - const ConfigTokenizer& tokenizer, - size_t i) { - tokenizer.checkToken(++i, ":"); - pinned_use_background_threads_ = tokenizer.toBool(++i); - - return i; -} - -void AcceleratorAllocatorConfig::parseArgs(const std::string& env) { - // The following option will be reset to its default value if not explicitly - // set each time. - max_split_size_ = std::numeric_limits::max(); - roundup_power2_divisions_.assign(kRoundUpPowerOfTwoIntervals, 0); - garbage_collection_threshold_ = 0; - - { - std::lock_guard lock(last_allocator_settings_mutex_); - last_allocator_settings_ = env; - } - - ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "max_split_size_mb") { - i = parseMaxSplitSize(tokenizer, i); - } else if (key == "max_non_split_rounding_mb") { - i = parseMaxNonSplitRoundingSize(tokenizer, i); - } else if (key == "garbage_collection_threshold") { - i = parseGarbageCollectionThreshold(tokenizer, i); - } else if (key == "roundup_power2_divisions") { - i = parseRoundUpPower2Divisions(tokenizer, i); - } else if (key == "expandable_segments") { - i = parseExpandableSegments(tokenizer, i); - } else if (key == "pinned_use_background_threads") { - i = parsePinnedUseBackgroundThreads(tokenizer, i); - } else { - // If a device-specific configuration parser hook is registered, it will - // check if the key is unrecognized. - if (device_config_parser_hook_) { - TORCH_CHECK( - keys_.find(key) != keys_.end(), - "Unrecognized key '", - key, - "' in Accelerator allocator config."); - } - i = tokenizer.skipKey(i); - } - - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); - } - } -} - -} // namespace c10::CachingAllocator diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h deleted file mode 100644 index eddaaa5ffc6c..000000000000 --- a/c10/core/AllocatorConfig.h +++ /dev/null @@ -1,370 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace c10::CachingAllocator { - -// "large" allocations may be packed in 20 MiB blocks -const size_t kLargeBuffer = 20971520; - -// A utility class for tokenizing allocator configuration strings into discrete -// parts. For example, the config string: -// "key1:val1,key2:[val2,val3]" -// is tokenized into: -// "key1", ":", "val1", ",", "key2", ":", "[", "val2", ",", "val3", "]", -// -// Tokens include keys, values, and special characters (':', ',', '[', ']'). -// Whitespace is ignored. -class ConfigTokenizer { - public: - explicit ConfigTokenizer(const std::string& env) { - std::string buffer; - for (char ch : env) { - if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { - if (!buffer.empty()) { - config_.emplace_back(std::move(buffer)); - buffer.clear(); - } - config_.emplace_back(1, ch); - } else if (!std::isspace(static_cast(ch))) { - buffer += ch; - } - } - if (!buffer.empty()) { - config_.emplace_back(std::move(buffer)); - } - } - - const std::string& operator[](size_t i) const { - TORCH_INTERNAL_ASSERT( - i < config_.size(), "Index out of bounds in ConfigTokenizer"); - return config_[i]; - } - - size_t size() const { - return config_.size(); - } - - bool checkToken(size_t i, const std::string& token) const { - checkIndex(i); - return config_[i] == token; - } - - size_t toSizeT(size_t i) const { - checkIndex(i); - return std::stoull(config_[i]); - } - - double toDouble(size_t i) const { - checkIndex(i); - return std::stod(config_[i]); - } - - bool toBool(size_t i) const { - checkIndex(i); - const auto& token = config_[i]; - if (token == "True") { - return true; - } else if (token == "False") { - return false; - } else { - TORCH_CHECK( - false, - "Expected 'True' or 'False' at index ", - i, - " in ConfigTokenizer but got '", - token, - "'"); - } - } - - // Skips the current token group and returns the index of the value token. - // Assumes the current index `i` points to a key name in a key-value pair. - size_t skipKey(size_t i) const { - // Expect a colon after the key - checkToken(++i, ":"); - - ++i; // Move to the value - checkIndex(i); - if (config_[i] != "[") { - // Value is a single token (not a list) -> return its index - return i; - } - - // Skip tokens inside the list until matching ']' - // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) - while (++i < config_.size() && config_[i] != "]") { - } - - TORCH_INTERNAL_ASSERT( - i < config_.size(), - "Expected closing bracket ']' in ConfigTokenizer but reached end of config"); - - return i; // Return the index of the closing ']' - } - - private: - void checkIndex(size_t i) const { - TORCH_INTERNAL_ASSERT( - i < config_.size(), "Index out of bounds in ConfigTokenizer"); - } - - std::vector config_; -}; - -/** - * Note [AcceleratorAllocatorConfig design] - * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * This class configures memory allocation for both device and host memory. A - * single `AcceleratorAllocatorConfig` instance is shared across all accelerator - * backends, such as CUDA and XPU, under the assumption that relevant - * environment variables apply uniformly to all accelerators. Device-specific - * configuration extensions are supported via hooks (see - * `registerDeviceConfigParserHook`). - * - * Recommended design: - * - Place common configurations in `AcceleratorAllocatorConfig`. - * - Extend backend-specific configurations in corresponding device-specific - * classes, such as `CUDAAllocatorConfig`, etc. - * - * Scope: - * - Configuration options must be environment-variable driven. - * - * Naming Convention: - * - Public API names in `AcceleratorAllocatorConfig` should be device-generic. - * - Members prefixed with `pinned_` are specific to the host/pinned allocator. - * - Environment variable names should be generic across backends. - * - Comma-separated key-value pairs in the format: `key:value`. Use square - * brackets `[]` for list values Example: `key1:123, key2:[val1,val2]` - * - * Environment Variables: - * - The primary environment variable for configuration is `PYTORCH_ALLOC_CONF`. - * - For backward compatibility, `PYTORCH_CUDA_ALLOC_CONF` is also supported - * with lower priority. - */ - -class C10_API AcceleratorAllocatorConfig { - public: - static AcceleratorAllocatorConfig& instance(); - - C10_DISABLE_COPY_AND_ASSIGN(AcceleratorAllocatorConfig); - AcceleratorAllocatorConfig(AcceleratorAllocatorConfig&&) = delete; - AcceleratorAllocatorConfig& operator=(AcceleratorAllocatorConfig&&) = delete; - ~AcceleratorAllocatorConfig() = default; - - /* Device allocator settings */ - - // Returns the maximum block size (in MB) that is allowed to be split. The - // default is unlimited (all blocks can be split). - static size_t max_split_size() { - return instance().max_split_size_; - } - - // Returns the maximum block size (in MB) that is allowed to be rounded up - // without requiring splitting when searching for a free block. The default is - // 20 MiB. - static size_t max_non_split_rounding_size() { - return instance().max_non_split_rounding_size_; - } - - // Return the number of divisions used when rounding up allocation sizes (in - // MB) to the nearest power-of-2 boundary. - static size_t roundup_power2_divisions(size_t size); - - // Returns the vector of division factors used for rounding up allocation - // sizes. These divisions apply to size intervals between 1MB and 64GB. - static const std::vector& roundup_power2_divisions() { - return instance().roundup_power2_divisions_; - } - - // Returns the threshold that triggers garbage collection when the ratio of - // used memory to maximum allowed memory exceeds this value. The default is 0, - // meaning no garbage collection is triggered. The value should be in the - // range (0.0, 1.0). - static double garbage_collection_threshold() { - return instance().garbage_collection_threshold_; - } - - // Returns whether the expandable segment feature is enabled. This allows the - // allocator to start with one segment that grows as needed, rather than - // creating a new segment for each allocation. Default is false (expandable - // segments disabled). - static bool use_expandable_segments() { - return instance().use_expandable_segments_; - } - - /* Host allocator settings */ - - // Returns whether the pinned host allocator uses background threads for - // processing events. This is useful for improving performance in scenarios - // where many small allocations are made. Default is false (background threads - // disabled). - static bool pinned_use_background_threads() { - return instance().pinned_use_background_threads_; - } - - /* Settings for both device and host allocator */ - - // Returns the current allocator settings as a string. This string is useful - // to expand device-specific allocator configurations - static std::string last_allocator_settings() { - std::lock_guard lock(instance().last_allocator_settings_mutex_); - return instance().last_allocator_settings_; - } - - // Returns the set of valid keys for the allocator configuration. - // This set is used to validate the presence and correctness of keys in - // device-specific configuration parsers. - static const std::unordered_set& getKeys() { - return instance().keys_; - } - - // Parses the environment variable `env` to update the allocator settings. - // If the environment variable is not set, it does nothing. - // The configuration string should be a comma-separated list of key-value - // pairs, where each key is a configuration option and the value is the - // corresponding setting. For example: - // "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true" - void parseArgs(const std::string& env); - - // Registers a device-specific configuration parser hook and its key. This - // allows backends to parse additional device-specific configuration options - // from the environment variable. The hook should be a function that takes a - // string (the environment variable value) and parses it to set - // device-specific configuration options. The hook will be called when the - // environment variable is parsed. If a hook is already registered, it will be - // replaced with the new one. - void registerDeviceConfigParserHook( - std::function&& hook, - const std::unordered_set& keys) { - device_config_parser_hook_ = std::move(hook); - for (auto& key : keys) { - TORCH_CHECK( - keys_.insert(key).second, - "Duplicated key '", - key, - "' found in device-specific configuration parser hook registration"); - } - } - - // Calls the registered device-specific configuration parser hook with the - // provided environment string. This allows backends to parse additional - // device-specific configuration options from the environment variable. - // If no hook is registered, this function does nothing. - void callDeviceConfigParserHook(const std::string& env) const { - if (device_config_parser_hook_) { - device_config_parser_hook_(env); - } - } - - private: - AcceleratorAllocatorConfig(); - - /* Internal functions for device allocator */ - - // Parse `max_split_size_mb` from environment variable. - size_t parseMaxSplitSize(const ConfigTokenizer& tokenizer, size_t i); - // Parse `max_non_split_rounding_mb` from environment variable. - size_t parseMaxNonSplitRoundingSize( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `garbage_collection_threshold` from environment variable. - size_t parseGarbageCollectionThreshold( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `roundup_power2_divisions` from environment variable. - size_t parseRoundUpPower2Divisions( - const ConfigTokenizer& tokenizer, - size_t i); - // Parse `expandable_segments` from environment variable. - size_t parseExpandableSegments(const ConfigTokenizer& tokenizer, size_t i); - - /* Internal functions for host allocator */ - - // Parse `pinned_use_background_threads` from environment variable. - size_t parsePinnedUseBackgroundThreads( - const ConfigTokenizer& tokenizer, - size_t i); - - /* The following members are specifically used for the device allocator. */ - - // The maximum block size that is allowed to be split. - std::atomic max_split_size_{std::numeric_limits::max()}; - // The maximum allowable extra size of a memory block without requiring - // splitting when searching for a free block. - std::atomic max_non_split_rounding_size_{kLargeBuffer}; - // Used to store how memory allocations of different sizes should be rounded - // up to the nearest power of 2 divisions. - std::vector roundup_power2_divisions_; - // The threshold that triggers garbage collection when the ratio of used - // memory to maximum allowed memory exceeds this value. - std::atomic garbage_collection_threshold_{0}; - // A flag to enable expandable segments feature. - std::atomic use_expandable_segments_{false}; - - /* The following members are specifically used for the host allocator. */ - - // A flag to enable background thread for processing events. - std::atomic pinned_use_background_threads_{false}; - - /* The following members are used for both device and host allocator. */ - - // Record the last allocator config environment setting. - std::mutex last_allocator_settings_mutex_; - std::string last_allocator_settings_; - - // Optional hook for parsing additional device-specific allocator settings. - // This allows backends (e.g., CUDA, XPU) to register a custom parser for - // their own environment configuration extensions. - std::function device_config_parser_hook_{nullptr}; - - // A set of valid configuration keys, including both common and - // device-specific options. This set is used to validate the presence and - // correctness of keys during parsing. - std::unordered_set keys_{ - "max_split_size_mb", - "max_non_split_rounding_mb", - "garbage_collection_threshold", - "roundup_power2_divisions", - "expandable_segments", - "pinned_use_background_threads"}; -}; - -C10_API inline void setAllocatorSettings(const std::string& env) { - AcceleratorAllocatorConfig::instance().parseArgs(env); - AcceleratorAllocatorConfig::instance().callDeviceConfigParserHook(env); -} - -C10_API inline std::string getAllocatorSettings() { - return AcceleratorAllocatorConfig::instance().last_allocator_settings(); -} - -struct DeviceConfigParserHookRegistry { - explicit DeviceConfigParserHookRegistry( - std::function&& hook, - const std::unordered_set& keys) { - AcceleratorAllocatorConfig::instance().registerDeviceConfigParserHook( - std::move(hook), keys); - } -}; - -// Assume each config parser has `parseArgs` and `getKeys` methods -#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \ - namespace { \ - static at::CachingAllocator::DeviceConfigParserHookRegistry \ - g_device_config_parse_hook_registry_instance( \ - [](const std::string& env) { \ - parser_cls::instance().parseArgs(env); \ - }, \ - parser_cls::getKeys()); \ - } - -} // namespace c10::CachingAllocator diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index ee51154b420b..68fa6f91979a 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -41,6 +41,9 @@ DeviceType parse_type(const std::string& device_string) { "'mkldnn' is no longer used as device type. So torch.device('mkldnn') will be " "deprecated and removed in the future. Please use other valid device types instead."); } + if (device_string == get_privateuse1_backend()) { + return DeviceType::PrivateUse1; + } auto device = std::find_if( types.begin(), types.end(), @@ -50,9 +53,6 @@ DeviceType parse_type(const std::string& device_string) { if (device != types.end()) { return device->second; } - if (device_string == get_privateuse1_backend()) { - return DeviceType::PrivateUse1; - } std::vector device_names; for (const auto& it : types) { if (it.first) { diff --git a/c10/core/impl/SizesAndStrides.h b/c10/core/impl/SizesAndStrides.h index b8a4de1c2d89..6cc87e1d6be3 100644 --- a/c10/core/impl/SizesAndStrides.h +++ b/c10/core/impl/SizesAndStrides.h @@ -64,6 +64,10 @@ class C10_API SizesAndStrides { storageBytes(size_))); } + bool operator!=(const SizesAndStrides& other) const { + return !(*this == other); + } + SizesAndStrides& operator=(const SizesAndStrides& rhs) { if (this == &rhs) { return *this; diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index ac59059c2cc2..d2efb8c593e4 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -1,121 +1,389 @@ #include -#include -#include +#include +#include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) #include #endif -#include - namespace c10::cuda::CUDACachingAllocator { -size_t CUDAAllocatorConfig::parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, +constexpr size_t kRoundUpPowerOfTwoIntervals = 16; + +CUDAAllocatorConfig::CUDAAllocatorConfig() + : m_max_split_size(std::numeric_limits::max()), + m_max_non_split_rounding_size(kLargeBuffer), + m_garbage_collection_threshold(0), + m_pinned_num_register_threads(1), + m_expandable_segments(false), +#if CUDA_VERSION >= 12030 + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::UNSPECIFIED), +#else + m_expandable_segments_handle_type( + Expandable_Segments_Handle_Type::POSIX_FD), +#endif + m_release_lock_on_cudamalloc(false), + m_pinned_use_cuda_host_register(false), + m_pinned_use_background_threads(false) { + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); +} + +size_t CUDAAllocatorConfig::roundup_power2_divisions(size_t size) { + size_t log_size = (63 - llvm::countLeadingZeros(size)); + + // Our intervals start at 1MB and end at 64GB + const size_t interval_start = + 63 - llvm::countLeadingZeros(static_cast(1048576)); + const size_t interval_end = + 63 - llvm::countLeadingZeros(static_cast(68719476736)); + TORCH_CHECK( + (interval_end - interval_start == kRoundUpPowerOfTwoIntervals), + "kRoundUpPowerOfTwoIntervals mismatch"); + + int index = static_cast(log_size) - static_cast(interval_start); + + index = std::max(0, index); + index = std::min(index, static_cast(kRoundUpPowerOfTwoIntervals) - 1); + return instance().m_roundup_power2_divisions[index]; +} + +void CUDAAllocatorConfig::lexArgs( + const std::string& env, + std::vector& config) { + std::vector buf; + + for (char ch : env) { + if (ch == ',' || ch == ':' || ch == '[' || ch == ']') { + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + buf.clear(); + } + config.emplace_back(1, ch); + } else if (ch != ' ') { + buf.emplace_back(ch); + } + } + if (!buf.empty()) { + config.emplace_back(buf.begin(), buf.end()); + } +} + +void CUDAAllocatorConfig::consumeToken( + const std::vector& config, + size_t i, + const char c) { + TORCH_CHECK( + i < config.size() && config[i] == std::string(1, c), + "Error parsing CachingAllocator settings, expected ", + c, + ""); +} + +size_t CUDAAllocatorConfig::parseMaxSplitSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_split_size_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_split_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_split_size_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + constexpr int mb = 1024 * 1024; + if (++i < config.size()) { + size_t val1 = stoi(config[i]); + TORCH_CHECK( + val1 > kLargeBuffer / mb, + "CachingAllocator option max_non_split_rounding_mb too small, must be > ", + kLargeBuffer / mb, + ""); + val1 = std::max(val1, kLargeBuffer / mb); + val1 = std::min(val1, (std::numeric_limits::max() / mb)); + m_max_non_split_rounding_size = val1 * 1024 * 1024; + } else { + TORCH_CHECK(false, "Error, expecting max_non_split_rounding_mb value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseGarbageCollectionThreshold( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + double val1 = stod(config[i]); + TORCH_CHECK( + val1 > 0, "garbage_collect_threshold too small, set it 0.0~1.0", ""); + TORCH_CHECK( + val1 < 1.0, "garbage_collect_threshold too big, set it 0.0~1.0", ""); + m_garbage_collection_threshold = val1; + } else { + TORCH_CHECK( + false, "Error, expecting garbage_collection_threshold value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseRoundUpPower2Divisions( + const std::vector& config, size_t i) { + consumeToken(config, ++i, ':'); + bool first_value = true; + + if (++i < config.size()) { + if (std::string_view(config[i]) == "[") { + size_t last_index = 0; + // NOLINTNEXTLINE(bugprone-inc-dec-in-conditions) + while (++i < config.size() && std::string_view(config[i]) != "]") { + const std::string& val1 = config[i]; + size_t val2 = 0; + + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + val2 = stoi(config[i]); + } else { + TORCH_CHECK( + false, "Error parsing roundup_power2_divisions value", ""); + } + TORCH_CHECK( + val2 == 0 || llvm::isPowerOf2_64(val2), + "For roundups, the divisions has to be power of 2 or 0 to disable roundup ", + ""); + + if (std::string_view(val1) == ">") { + std::fill( + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + last_index)), + m_roundup_power2_divisions.end(), + val2); + } else { + size_t val1_long = stoul(val1); + TORCH_CHECK( + llvm::isPowerOf2_64(val1_long), + "For roundups, the intervals have to be power of 2 ", + ""); + + size_t index = 63 - llvm::countLeadingZeros(val1_long); + index = std::max((size_t)0, index); + index = std::min(index, m_roundup_power2_divisions.size() - 1); + + if (first_value) { + std::fill( + m_roundup_power2_divisions.begin(), + std::next( + m_roundup_power2_divisions.begin(), + static_cast::difference_type>( + index)), + val2); + first_value = false; + } + if (index < m_roundup_power2_divisions.size()) { + m_roundup_power2_divisions[index] = val2; + } + last_index = index; + } + + if (std::string_view(config[i + 1]) != "]") { + consumeToken(config, ++i, ','); + } + } + } else { // Keep this for backwards compatibility + size_t val1 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val1), + "For roundups, the divisions has to be power of 2 ", + ""); + std::fill( + m_roundup_power2_divisions.begin(), + m_roundup_power2_divisions.end(), + val1); + } + } else { + TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync) { // For ease of maintenance and understanding, the CUDA and ROCm // implementations of this function are separated. This avoids having many // #ifdef's throughout. +#ifdef USE_ROCM // Ease burden on ROCm users by allowing either cuda or hip tokens. // cuda token is broken up to prevent hipify matching it. #define PYTORCH_TOKEN1 \ "cud" \ "aMallocAsync" #define PYTORCH_TOKEN2 "hipMallocAsync" - tokenizer.checkToken(++i, ":"); - i++; // Move to the value after the colon - TORCH_CHECK( - ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) || - (tokenizer[i] == PYTORCH_TOKEN2)), - "Unknown allocator backend, " - "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); - if (m_is_allocator_loaded) { - bool aync_allocator_at_runtime = (tokenizer[i] != "native"); + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - aync_allocator_at_runtime == m_use_async_allocator, - "Allocator async backend parsed at runtime != allocator async backend parsed at load time, ", - aync_allocator_at_runtime, + ((config[i] == "native") || (config[i] == PYTORCH_TOKEN1) || + (config[i] == PYTORCH_TOKEN2)), + "Unknown allocator backend, " + "options are native, " PYTORCH_TOKEN1 ", and " PYTORCH_TOKEN2); + used_cudaMallocAsync = + (config[i] == PYTORCH_TOKEN1 || config[i] == PYTORCH_TOKEN2); + TORCH_INTERNAL_ASSERT( + config[i] == get()->name() || + (config[i] == PYTORCH_TOKEN1 && get()->name() == PYTORCH_TOKEN2), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time, ", + config[i], " != ", - m_use_async_allocator); + get()->name()); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } - m_use_async_allocator = - (tokenizer[i] == PYTORCH_TOKEN1 || tokenizer[i] == PYTORCH_TOKEN2); - // CUDA allocator is always loaded at the start of the program - m_is_allocator_loaded = true; - -#if defined(CUDA_VERSION) - if (m_use_async_allocator) { -#if CUDA_VERSION >= 11040 - int version = 0; - C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + return i; +#undef PYTORCH_TOKEN1 +#undef PYTORCH_TOKEN2 +#else // USE_ROCM + consumeToken(config, ++i, ':'); + if (++i < config.size()) { TORCH_CHECK( - version >= 11040, - "backend:cudaMallocAsync requires CUDA runtime " - "11.4 or newer, but cudaDriverGetVersion returned ", - version); + ((config[i] == "native") || (config[i] == "cudaMallocAsync")), + "Unknown allocator backend, " + "options are native and cudaMallocAsync"); + used_cudaMallocAsync = (config[i] == "cudaMallocAsync"); + if (used_cudaMallocAsync) { +#if CUDA_VERSION >= 11040 + int version = 0; + C10_CUDA_CHECK(cudaDriverGetVersion(&version)); + TORCH_CHECK( + version >= 11040, + "backend:cudaMallocAsync requires CUDA runtime " + "11.4 or newer, but cudaDriverGetVersion returned ", + version); #else - TORCH_CHECK( - false, - "backend:cudaMallocAsync requires PyTorch to be built with " - "CUDA 11.4 or newer, but CUDA_VERSION is ", - CUDA_VERSION); + TORCH_CHECK( + false, + "backend:cudaMallocAsync requires PyTorch to be built with " + "CUDA 11.4 or newer, but CUDA_VERSION is ", + CUDA_VERSION); #endif + } + TORCH_INTERNAL_ASSERT( + config[i] == get()->name(), + "Allocator backend parsed at runtime != " + "allocator backend parsed at load time"); + } else { + TORCH_CHECK(false, "Error parsing backend value", ""); } -#endif - return i; -#undef PYTORCH_TOKEN1 -#undef PYTORCH_TOKEN2 +#endif // USE_ROCM } -void CUDAAllocatorConfig::parseArgs(const std::string& env) { +void CUDAAllocatorConfig::parseArgs(const std::optional& env) { // If empty, set the default values + m_max_split_size = std::numeric_limits::max(); + m_roundup_power2_divisions.assign(kRoundUpPowerOfTwoIntervals, 0); + m_garbage_collection_threshold = 0; + bool used_cudaMallocAsync = false; bool used_native_specific_option = false; - c10::CachingAllocator::ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "backend") { - i = parseAllocatorConfig(tokenizer, i); + if (!env.has_value()) { + return; + } + { + std::lock_guard lock(m_last_allocator_settings_mutex); + m_last_allocator_settings = env.value(); + } + + std::vector config; + lexArgs(env.value(), config); + + for (size_t i = 0; i < config.size(); i++) { + std::string_view config_item_view(config[i]); + if (config_item_view == "max_split_size_mb") { + i = parseMaxSplitSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "max_non_split_rounding_mb") { + i = parseMaxNonSplitRoundingSize(config, i); + used_native_specific_option = true; + } else if (config_item_view == "garbage_collection_threshold") { + i = parseGarbageCollectionThreshold(config, i); + used_native_specific_option = true; + } else if (config_item_view == "roundup_power2_divisions") { + i = parseRoundUpPower2Divisions(config, i); + used_native_specific_option = true; + } else if (config_item_view == "backend") { + i = parseAllocatorConfig(config, i, used_cudaMallocAsync); + } else if (config_item_view == "expandable_segments") { + used_native_specific_option = true; + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for expandable_segments"); + config_item_view = config[i]; + m_expandable_segments = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "release_lock_on_hipmalloc" || - key == + config_item_view == "release_lock_on_hipmalloc" || + config_item_view == "release_lock_on_c" "udamalloc") { used_native_specific_option = true; - tokenizer.checkToken(++i, ":"); - m_release_lock_on_cudamalloc = tokenizer.toBool(++i); + consumeToken(config, ++i, ':'); + ++i; + TORCH_CHECK( + i < config.size() && + (std::string_view(config[i]) == "True" || + std::string_view(config[i]) == "False"), + "Expected a single True/False argument for release_lock_on_cudamalloc"); + config_item_view = config[i]; + m_release_lock_on_cudamalloc = (config_item_view == "True"); } else if ( // ROCm build's hipify step will change "cuda" to "hip", but for ease of // use, accept both. We must break up the string to prevent hipify here. - key == "pinned_use_hip_host_register" || - key == + config_item_view == "pinned_use_hip_host_register" || + config_item_view == "pinned_use_c" "uda_host_register") { - i = parsePinnedUseCudaHostRegister(tokenizer, i); + i = parsePinnedUseCudaHostRegister(config, i); used_native_specific_option = true; - } else if (key == "pinned_num_register_threads") { - i = parsePinnedNumRegisterThreads(tokenizer, i); + } else if (config_item_view == "pinned_num_register_threads") { + i = parsePinnedNumRegisterThreads(config, i); + used_native_specific_option = true; + } else if (config_item_view == "pinned_use_background_threads") { + i = parsePinnedUseBackgroundThreads(config, i); used_native_specific_option = true; } else { - const auto& keys = - c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); TORCH_CHECK( - keys.find(key) != keys.end(), - "Unrecognized key '", - key, - "' in Accelerator allocator config."); - i = tokenizer.skipKey(i); + false, "Unrecognized CachingAllocator option: ", config_item_view); } - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); + if (i + 1 < config.size()) { + consumeToken(config, ++i, ','); } } - if (m_use_async_allocator && used_native_specific_option) { + if (used_cudaMallocAsync && used_native_specific_option) { TORCH_WARN( "backend:cudaMallocAsync ignores max_split_size_mb," "roundup_power2_divisions, and garbage_collect_threshold."); @@ -123,33 +391,64 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } size_t CUDAAllocatorConfig::parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - m_pinned_use_cuda_host_register = tokenizer.toBool(++i); - + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_cuda_host_register"); + m_pinned_use_cuda_host_register = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_cuda_host_register value", ""); + } return i; } size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i) { - tokenizer.checkToken(++i, ":"); - size_t val2 = tokenizer.toSizeT(++i); - TORCH_CHECK( - llvm::isPowerOf2_64(val2), - "Number of register threads has to be power of 2 ", - ""); - auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); - TORCH_CHECK( - val2 <= maxThreads, - "Number of register threads should be less than or equal to " + - std::to_string(maxThreads), - ""); - m_pinned_num_register_threads = val2; + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + size_t val2 = stoi(config[i]); + TORCH_CHECK( + llvm::isPowerOf2_64(val2), + "Number of register threads has to be power of 2 ", + ""); + auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads(); + TORCH_CHECK( + val2 <= maxThreads, + "Number of register threads should be less than or equal to " + + std::to_string(maxThreads), + ""); + m_pinned_num_register_threads = val2; + } else { + TORCH_CHECK( + false, "Error, expecting pinned_num_register_threads value", ""); + } + return i; +} + +size_t CUDAAllocatorConfig::parsePinnedUseBackgroundThreads( + const std::vector& config, + size_t i) { + consumeToken(config, ++i, ':'); + if (++i < config.size()) { + TORCH_CHECK( + (config[i] == "True" || config[i] == "False"), + "Expected a single True/False argument for pinned_use_background_threads"); + m_pinned_use_background_threads = (config[i] == "True"); + } else { + TORCH_CHECK( + false, "Error, expecting pinned_use_background_threads value", ""); + } return i; } -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig) +// General caching allocator utilities +void setAllocatorSettings(const std::string& env) { + CUDACachingAllocator::CUDAAllocatorConfig::instance().parseArgs(env.c_str()); +} } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index 9aa128c26bd0..fda3cc02e5d0 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -1,11 +1,16 @@ #pragma once -#include #include -#include #include #include +#include +#include +#include +#include +#include +#include + namespace c10::cuda::CUDACachingAllocator { enum class Expandable_Segments_Handle_Type : int { @@ -17,28 +22,21 @@ enum class Expandable_Segments_Handle_Type : int { // Environment config parser class C10_CUDA_API CUDAAllocatorConfig { public: - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_split_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size() instead.") static size_t max_split_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig::max_split_size(); + return instance().m_max_split_size; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::garbage_collection_threshold() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::garbage_collection_threshold() instead.") static double garbage_collection_threshold() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - garbage_collection_threshold(); + return instance().m_garbage_collection_threshold; } static bool expandable_segments() { - bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: - use_expandable_segments(); #ifndef PYTORCH_C10_DRIVER_API_SUPPORTED - if (enabled) { + if (instance().m_expandable_segments) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } return false; #else - return enabled; + return instance().m_expandable_segments; #endif } @@ -64,11 +62,8 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_pinned_num_register_threads; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_background_threads() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::pinned_use_background_threads() instead.") static bool pinned_use_background_threads() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - pinned_use_background_threads(); + return instance().m_pinned_use_background_threads; } static size_t pinned_max_register_threads() { @@ -78,105 +73,92 @@ class C10_CUDA_API CUDAAllocatorConfig { return 128; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") - static size_t roundup_power2_divisions(size_t size) { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(size); - } + // This is used to round-up allocation size to nearest power of 2 divisions. + // More description below in function roundup_power2_next_division + // As an example, if we want 4 divisions between 2's power, this can be done + // using env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4 + static size_t roundup_power2_divisions(size_t size); - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static std::vector roundup_power2_divisions() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - roundup_power2_divisions(); + return instance().m_roundup_power2_divisions; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::max_non_split_rounding_size() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::max_non_split_rounding_size() instead.") static size_t max_non_split_rounding_size() { - return c10::CachingAllocator::AcceleratorAllocatorConfig:: - max_non_split_rounding_size(); + return instance().m_max_non_split_rounding_size; } - C10_DEPRECATED_MESSAGE( - "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::last_allocator_settings() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::last_allocator_settings() instead.") static std::string last_allocator_settings() { - return c10::CachingAllocator::getAllocatorSettings(); - } - - static bool use_async_allocator() { - return instance().m_use_async_allocator; - } - - static const std::unordered_set& getKeys() { - return instance().keys_; + std::lock_guard lock( + instance().m_last_allocator_settings_mutex); + return instance().m_last_allocator_settings; } static CUDAAllocatorConfig& instance() { static CUDAAllocatorConfig* s_instance = ([]() { auto inst = new CUDAAllocatorConfig(); - auto env = c10::utils::get_env("PYTORCH_ALLOC_CONF"); - if (!env.has_value()) { - // For backward compatibility, check for the old environment variable - // PYTORCH_CUDA_ALLOC_CONF. - env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); - } + auto env = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); #ifdef USE_ROCM // convenience for ROCm users, allow alternative HIP token if (!env.has_value()) { env = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); } #endif - if (env.has_value()) { - inst->parseArgs(env.value()); - } + inst->parseArgs(env); return inst; })(); return *s_instance; } - void parseArgs(const std::string& env); + void parseArgs(const std::optional& env); private: - CUDAAllocatorConfig() = default; - - size_t parseAllocatorConfig( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + CUDAAllocatorConfig(); + + static void lexArgs(const std::string& env, std::vector& config); + static void consumeToken( + const std::vector& config, + size_t i, + const char c); + size_t parseMaxSplitSize(const std::vector& config, size_t i); + size_t parseMaxNonSplitRoundingSize( + const std::vector& config, + size_t i); + size_t parseGarbageCollectionThreshold( + const std::vector& config, + size_t i); + size_t parseRoundUpPower2Divisions( + const std::vector& config, size_t i); + size_t parseAllocatorConfig( + const std::vector& config, + size_t i, + bool& used_cudaMallocAsync); size_t parsePinnedUseCudaHostRegister( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, size_t i); size_t parsePinnedNumRegisterThreads( - const c10::CachingAllocator::ConfigTokenizer& tokenizer, + const std::vector& config, + size_t i); + size_t parsePinnedUseBackgroundThreads( + const std::vector& config, size_t i); - std::atomic m_pinned_num_register_threads{1}; - std::atomic m_expandable_segments_handle_type -#if CUDA_VERSION >= 12030 - {Expandable_Segments_Handle_Type::UNSPECIFIED}; -#else - {Expandable_Segments_Handle_Type::POSIX_FD}; -#endif - std::atomic m_release_lock_on_cudamalloc{false}; - std::atomic m_pinned_use_cuda_host_register{false}; - std::atomic m_use_async_allocator{false}; - std::atomic m_is_allocator_loaded{false}; - std::unordered_set keys_{ - "backend", - // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues - // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_cud" - "amalloc", - "pinned_use_cud" - "a_host_register", - // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors) - "release_lock_on_hipmalloc", - "pinned_use_hip_host_register", - "pinned_num_register_threads"}; + std::atomic m_max_split_size; + std::atomic m_max_non_split_rounding_size; + std::vector m_roundup_power2_divisions; + std::atomic m_garbage_collection_threshold; + std::atomic m_pinned_num_register_threads; + std::atomic m_expandable_segments; + std::atomic + m_expandable_segments_handle_type; + std::atomic m_release_lock_on_cudamalloc; + std::atomic m_pinned_use_cuda_host_register; + std::atomic m_pinned_use_background_threads; + std::string m_last_allocator_settings; + std::mutex m_last_allocator_settings_mutex; }; -// Keep this for backwards compatibility -using c10::CachingAllocator::setAllocatorSettings; +// General caching allocator utilities +C10_CUDA_API void setAllocatorSettings(const std::string& env); } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 5ae04bcd3f53..4d58c11c5c9b 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -63,6 +64,10 @@ namespace cuda::CUDACachingAllocator { using namespace c10::CachingAllocator; using namespace c10::CachingDeviceAllocator; +// Included here as this is externally used in CUDAAllocatorConfig +const size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks + namespace Native { // @@ -1226,7 +1231,7 @@ class DeviceCachingAllocator { DeviceCachingAllocator() : large_blocks(/*small=*/false), small_blocks(/*small=*/true) { stats.max_split_size = - static_cast(AcceleratorAllocatorConfig::max_split_size()); + static_cast(CUDAAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); } @@ -1351,8 +1356,7 @@ class DeviceCachingAllocator { // Do garbage collection if the flag is set. if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > - 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); } // Attempt allocate @@ -1604,7 +1608,7 @@ class DeviceCachingAllocator { stats.active_bytes[stat_type].increase(block->size); stats.requested_bytes[stat_type].increase(block->requested_size); }); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.increase(1); auto allocated_bytes_gauge = @@ -1655,7 +1659,7 @@ class DeviceCachingAllocator { block->pool->owner_MempoolId(), context ? context : block->context_when_allocated); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_allocations.decrease(1); if (!block->stream_uses.empty()) { @@ -2205,8 +2209,7 @@ class DeviceCachingAllocator { if (size < kMinBlockSize) { return kMinBlockSize; } else { - auto divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(size); + auto divisions = CUDAAllocatorConfig::roundup_power2_divisions(size); if (divisions > 1 && size > (kMinBlockSize * divisions)) { return roundup_power2_next_division(size, divisions); } else { @@ -2696,7 +2699,7 @@ class DeviceCachingAllocator { if (block->pool->is_small || CUDAAllocatorConfig::expandable_segments()) { return remaining >= kMinBlockSize; } else { - return (size < AcceleratorAllocatorConfig::max_split_size()) && + return (size < CUDAAllocatorConfig::max_split_size()) && (remaining > kSmallSize); } } @@ -2716,7 +2719,7 @@ class DeviceCachingAllocator { if (C10_UNLIKELY( set_fraction && - AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { + CUDAAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; } @@ -2758,13 +2761,13 @@ class DeviceCachingAllocator { } // Do not return an oversized block for a large request - if ((p.size() < AcceleratorAllocatorConfig::max_split_size()) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size())) + if ((p.size() < CUDAAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size())) return false; // Allow oversized block size to be rounded up but within a limit - if ((p.size() >= AcceleratorAllocatorConfig::max_split_size()) && + if ((p.size() >= CUDAAllocatorConfig::max_split_size()) && ((*it)->size >= - p.size() + AcceleratorAllocatorConfig::max_non_split_rounding_size())) + p.size() + CUDAAllocatorConfig::max_non_split_rounding_size())) return false; p.block = *it; pool.blocks.erase(it); @@ -2787,7 +2790,7 @@ class DeviceCachingAllocator { // therefore should be of less overheads. size_t gc_threshold = static_cast( - AcceleratorAllocatorConfig::garbage_collection_threshold() * + CUDAAllocatorConfig::garbage_collection_threshold() * static_cast(allowed_memory_maximum)); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { @@ -2935,7 +2938,7 @@ class DeviceCachingAllocator { stats.segment[stat_type].increase(1); stats.reserved_bytes[stat_type].increase(size); }); - if (size >= AcceleratorAllocatorConfig::max_split_size()) + if (size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.increase(1); auto reserved_bytes_gauge = STATIC_GAUGE(pytorch.CUDACachingAllocator.reserved_bytes); @@ -2964,7 +2967,7 @@ class DeviceCachingAllocator { bool release_available_cached_blocks( const AllocParams& p, const std::shared_ptr& context) { - if (AcceleratorAllocatorConfig::max_split_size() == + if (CUDAAllocatorConfig::max_split_size() == std::numeric_limits::max()) return false; BlockPool& pool = *p.pool; @@ -2972,8 +2975,8 @@ class DeviceCachingAllocator { // because of std::unique_ptr, block cannot be trivially copied // Use constructor for search key. Block key(p.search_key.device, p.search_key.stream, p.search_key.size); - key.size = (key.size < AcceleratorAllocatorConfig::max_split_size()) - ? AcceleratorAllocatorConfig::max_split_size() + key.size = (key.size < CUDAAllocatorConfig::max_split_size()) + ? CUDAAllocatorConfig::max_split_size() : key.size; auto it = pool.blocks.lower_bound(&key); if (it == pool.blocks.end() || (*it)->stream != p.stream() || @@ -2986,7 +2989,7 @@ class DeviceCachingAllocator { --it; // Back up one item. Now on the largest block for the correct // stream while ((totalReleased < key.size) && - ((*it)->size >= AcceleratorAllocatorConfig::max_split_size()) && + ((*it)->size >= CUDAAllocatorConfig::max_split_size()) && ((*it)->stream == p.stream())) { auto cur = it; bool is_first = cur == pool.blocks.begin(); @@ -3111,7 +3114,7 @@ class DeviceCachingAllocator { stats.reserved_bytes[static_cast(StatType::AGGREGATE)] .current); - if (block->size >= AcceleratorAllocatorConfig::max_split_size()) + if (block->size >= CUDAAllocatorConfig::max_split_size()) stats.oversize_segments.decrease(1); pool->blocks.erase(block); delete block; @@ -3738,8 +3741,8 @@ class NativeCachingAllocator : public CUDAAllocator { auto& md = result.config_metadata; md.garbage_collection_threshold = - AcceleratorAllocatorConfig::garbage_collection_threshold(); - md.max_split_size = AcceleratorAllocatorConfig::max_split_size(); + CUDAAllocatorConfig::garbage_collection_threshold(); + md.max_split_size = CUDAAllocatorConfig::max_split_size(); md.pinned_num_register_threads = CUDAAllocatorConfig::pinned_num_register_threads(); md.expandable_segments = CUDAAllocatorConfig::expandable_segments(); @@ -3747,10 +3750,9 @@ class NativeCachingAllocator : public CUDAAllocator { CUDAAllocatorConfig::release_lock_on_cudamalloc(); md.pinned_use_host_register = CUDAAllocatorConfig::pinned_use_cuda_host_register(); - md.last_allocator_settings = - AcceleratorAllocatorConfig::last_allocator_settings(); + md.last_allocator_settings = CUDAAllocatorConfig::last_allocator_settings(); md.roundup_power2_divisions = - AcceleratorAllocatorConfig::roundup_power2_divisions(); + CUDAAllocatorConfig::roundup_power2_divisions(); return result; } @@ -4128,10 +4130,49 @@ CUDAAllocator* allocator(); } // namespace CudaMallocAsync struct BackendStaticInitializer { + // Parses env for backend at load time, duplicating some logic from + // CUDAAllocatorConfig. CUDAAllocatorConfig double-checks it later (at + // runtime). Defers verbose exceptions and error checks, including Cuda + // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this + // works, maybe we should move all of CUDAAllocatorConfig here? CUDAAllocator* parseEnvForBackend() { - // If the environment variable is set, we use the CudaMallocAsync allocator. - if (CUDAAllocatorConfig::use_async_allocator()) { - return CudaMallocAsync::allocator(); + auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (!val.has_value()) { + val = c10::utils::get_env("PYTORCH_HIP_ALLOC_CONF"); + } +#endif + if (val.has_value()) { + const std::string& config = val.value(); + + std::regex exp("[\\s,]+"); + std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); + std::sregex_token_iterator end; + std::vector options(it, end); + + for (auto option : options) { + std::regex exp2("[:]+"); + std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); + std::sregex_token_iterator end2; + std::vector kv(it2, end2); + if (kv.size() >= 2) { + if (kv[0] == "backend") { +#ifdef USE_ROCM + // convenience for ROCm users to allow either CUDA or HIP env var + if (kv[1] == + "cud" + "aMallocAsync" || + kv[1] == "hipMallocAsync") +#else + if (kv[1] == "cudaMallocAsync") +#endif + return CudaMallocAsync::allocator(); + if (kv[1] == "native") + return &Native::allocator; + } + } + } } return &Native::allocator; } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 956411fe2282..a6fa61110d67 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -50,9 +49,10 @@ namespace c10::cuda::CUDACachingAllocator { // Preserved only for BC reasons // NOLINTNEXTLINE(misc-unused-using-decls) -using c10::CachingAllocator::kLargeBuffer; using c10::CachingDeviceAllocator::DeviceStats; +extern const size_t kLargeBuffer; + typedef std::shared_ptr (*CreateContextFn)(); // Struct containing info of an allocation block (i.e. a fractional part of a diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 05f00e43a2a7..457d35f020bb 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -28,7 +28,9 @@ void c10_cuda_check_implementation( std::string check_message; #ifndef STRIP_ERROR_MESSAGES check_message.append("CUDA error: "); - check_message.append(cudaGetErrorString(cuda_error)); + const char* error_string = cudaGetErrorString(cuda_error); + check_message.append(error_string); + check_message.append(c10::cuda::get_cuda_error_help(cuda_error)); check_message.append(c10::cuda::get_cuda_check_suffix()); check_message.append("\n"); if (include_device_assertions) { diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 2c7aa99feeb3..543c86602746 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -90,8 +90,17 @@ C10_CUDA_API void __inline__ memcpy_and_sync( (*interp)->trace_gpu_stream_synchronization( c10::kCUDA, reinterpret_cast(stream)); } -#if defined(TORCH_HIP_VERSION) && (TORCH_HIP_VERSION >= 301) - C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); +#if defined(USE_ROCM) && USE_ROCM + // As of ROCm 6.4.1, HIP runtime does not raise an error during capture of + // hipMemcpyWithStream which is a synchronous call. Thus, we add a check + // here explicitly. + hipStreamCaptureStatus captureStatus; + C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr)); + if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) { + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); + } else { + C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); + } #else C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); C10_CUDA_CHECK(cudaStreamSynchronize(stream)); diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index cc6519728f1e..b1b6170f891e 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,8 +1,30 @@ #include #include +#include +#include +#include namespace c10::cuda { +// Explain common CUDA errors +// NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) +std::string get_cuda_error_help(cudaError_t error) noexcept { + std::string help_text; + switch (error) { + case cudaErrorInvalidDevice: + help_text.append( + "\nGPU device may be out of range, do you have enough GPUs?"); + break; + default: + help_text.append("\nSearch for `") + .append(cudaGetErrorName(error)) + .append( + "' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information."); + break; + } + return help_text; +} + // NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) const char* get_cuda_check_suffix() noexcept { static auto device_blocking_flag = diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index dc3fced770ba..ec1114935457 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -3,10 +3,13 @@ // CUDAExceptions.h #include +#include #include +#include namespace c10::cuda { +C10_CUDA_API std::string get_cuda_error_help(cudaError_t) noexcept; C10_CUDA_API const char* get_cuda_check_suffix() noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda diff --git a/c10/cuda/driver_api.cpp b/c10/cuda/driver_api.cpp index bb201b5c0397..f4b62e53fcc0 100644 --- a/c10/cuda/driver_api.cpp +++ b/c10/cuda/driver_api.cpp @@ -1,30 +1,35 @@ #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include #include #include #include +#include +#include #include namespace c10::cuda { namespace { +void* get_symbol(const char* name, int version); + DriverAPI create_driver_api() { - void* handle_0 = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_NOLOAD); - TORCH_CHECK(handle_0, "Can't open libcuda.so.1: ", dlerror()); void* handle_1 = DriverAPI::get_nvml_handle(); DriverAPI r{}; -#define LOOKUP_LIBCUDA_ENTRY(name) \ - r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \ - TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror()) - C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY) -#undef LOOKUP_LIBCUDA_ENTRY +#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED(name, version) \ + r.name##_ = reinterpret_cast(get_symbol(#name, version)); \ + TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name); + C10_LIBCUDA_DRIVER_API_REQUIRED(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED) +#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_REQUIRED -#define LOOKUP_LIBCUDA_ENTRY(name) \ - r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \ - dlerror(); - C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY) -#undef LOOKUP_LIBCUDA_ENTRY +// Users running drivers between 12.0 and 12.3 will not have these symbols, +// they would be resolved into nullptr, but we guard their usage at runtime +// to ensure safe fallback behavior. +#define LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL(name, version) \ + r.name##_ = reinterpret_cast(get_symbol(#name, version)); + C10_LIBCUDA_DRIVER_API_OPTIONAL(LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL) +#undef LOOKUP_LIBCUDA_ENTRY_WITH_VERSION_OPTIONAL if (handle_1) { #define LOOKUP_NVML_ENTRY(name) \ @@ -35,6 +40,32 @@ DriverAPI create_driver_api() { } return r; } + +void* get_symbol(const char* name, int version) { + void* out = nullptr; + cudaDriverEntryPointQueryResult qres{}; + + // CUDA 12.5+ supports version-based lookup +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12050) + if (auto st = cudaGetDriverEntryPointByVersion( + name, &out, version, cudaEnableDefault, &qres); + st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) { + return out; + } +#endif + + // This fallback to the old API to try getting the symbol again. + if (auto st = cudaGetDriverEntryPoint(name, &out, cudaEnableDefault, &qres); + st == cudaSuccess && qres == cudaDriverEntryPointSuccess && out) { + return out; + } + + // If the symbol cannot be resolved, report and return nullptr; + // the caller is responsible for checking the pointer. + LOG(INFO) << "Failed to resolve symbol " << name; + return nullptr; +} + } // namespace void* DriverAPI::get_nvml_handle() { diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index a8ded9de68d7..9800809d1e53 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,29 +20,42 @@ } \ } while (0) -#define C10_LIBCUDA_DRIVER_API(_) \ - _(cuDeviceGetAttribute) \ - _(cuMemAddressReserve) \ - _(cuMemRelease) \ - _(cuMemMap) \ - _(cuMemAddressFree) \ - _(cuMemSetAccess) \ - _(cuMemUnmap) \ - _(cuMemCreate) \ - _(cuMemGetAllocationGranularity) \ - _(cuMemExportToShareableHandle) \ - _(cuMemImportFromShareableHandle) \ - _(cuMemsetD32Async) \ - _(cuStreamWriteValue32) \ - _(cuGetErrorString) +// The integer in the second column specifies the requested CUDA Driver API +// version. The dynamic loader will accept a driver with a newer version, but it +// ensures that the requested symbol exists in *at least* the specified version +// or earlier. + +// Keep these requested versions as low as possible to maximize compatibility +// across different driver versions. + +// Why do we pin to an older version instead of using the latest? +// If a user installs a newer driver, blindly resolving the symbol may bind to a +// newer version of the function with different behavior, potentially breaking +// PyTorch. + +#define C10_LIBCUDA_DRIVER_API_REQUIRED(_) \ + _(cuDeviceGetAttribute, 12000) \ + _(cuMemAddressReserve, 12000) \ + _(cuMemRelease, 12000) \ + _(cuMemMap, 12000) \ + _(cuMemAddressFree, 12000) \ + _(cuMemSetAccess, 12000) \ + _(cuMemUnmap, 12000) \ + _(cuMemCreate, 12000) \ + _(cuMemGetAllocationGranularity, 12000) \ + _(cuMemExportToShareableHandle, 12000) \ + _(cuMemImportFromShareableHandle, 12000) \ + _(cuMemsetD32Async, 12000) \ + _(cuStreamWriteValue32, 12000) \ + _(cuGetErrorString, 12000) #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) -#define C10_LIBCUDA_DRIVER_API_12030(_) \ - _(cuMulticastAddDevice) \ - _(cuMulticastBindMem) \ - _(cuMulticastCreate) +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) \ + _(cuMulticastAddDevice, 12030) \ + _(cuMulticastBindMem, 12030) \ + _(cuMulticastCreate, 12030) #else -#define C10_LIBCUDA_DRIVER_API_12030(_) +#define C10_LIBCUDA_DRIVER_API_OPTIONAL(_) #endif #define C10_NVML_DRIVER_API(_) \ @@ -56,11 +69,14 @@ namespace c10::cuda { struct DriverAPI { +#define CREATE_MEMBER_VERSIONED(name, version) decltype(&name) name##_; #define CREATE_MEMBER(name) decltype(&name) name##_; - C10_LIBCUDA_DRIVER_API(CREATE_MEMBER) - C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER) + C10_LIBCUDA_DRIVER_API_REQUIRED(CREATE_MEMBER_VERSIONED) + C10_LIBCUDA_DRIVER_API_OPTIONAL(CREATE_MEMBER_VERSIONED) C10_NVML_DRIVER_API(CREATE_MEMBER) +#undef CREATE_MEMBER_VERSIONED #undef CREATE_MEMBER + static DriverAPI* get(); static void* get_nvml_handle(); }; diff --git a/c10/macros/Export.h b/c10/macros/Export.h index b013910902b2..1b8a6811c53f 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -1,78 +1 @@ -#ifndef C10_MACROS_EXPORT_H_ -#define C10_MACROS_EXPORT_H_ - -#ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include -#endif // C10_USING_CUSTOM_GENERATED_MACROS - #include - -// This one is being used by libtorch.so -#ifdef CAFFE2_BUILD_MAIN_LIB -#define TORCH_API C10_EXPORT -#else -#define TORCH_API C10_IMPORT -#endif - -// You may be wondering: Whose brilliant idea was it to split torch_cuda into -// two pieces with confusing names? -// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we -// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker -// issues when linking big binaries. -// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: -// (1) Stop supporting so many GPU architectures -// (2) Do something else -// We chose #2 and decided to split the behemoth that was torch_cuda into two -// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) -// and the other that had..well..everything else (torch_cuda_cpp). The idea was -// this: instead of linking our static libraries (like the hefty -// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky -// relocation marker issues, we could link our static libraries to a smaller -// part of torch_cuda (torch_cuda_cpp) and avoid the issues. - -// libtorch_cuda_cu.so -#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB -#define TORCH_CUDA_CU_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CU_API C10_IMPORT -#endif - -// libtorch_cuda_cpp.so -#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB -#define TORCH_CUDA_CPP_API C10_EXPORT -#elif defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CPP_API C10_IMPORT -#endif - -// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the -// same api) -#ifdef TORCH_CUDA_BUILD_MAIN_LIB -#define TORCH_CUDA_CPP_API C10_EXPORT -#define TORCH_CUDA_CU_API C10_EXPORT -#elif !defined(BUILD_SPLIT_CUDA) -#define TORCH_CUDA_CPP_API C10_IMPORT -#define TORCH_CUDA_CU_API C10_IMPORT -#endif - -#if defined(TORCH_HIP_BUILD_MAIN_LIB) -#define TORCH_HIP_CPP_API C10_EXPORT -#define TORCH_HIP_API C10_EXPORT -#else -#define TORCH_HIP_CPP_API C10_IMPORT -#define TORCH_HIP_API C10_IMPORT -#endif - -#if defined(TORCH_XPU_BUILD_MAIN_LIB) -#define TORCH_XPU_API C10_EXPORT -#else -#define TORCH_XPU_API C10_IMPORT -#endif - -// Enums only need to be exported on windows for non-CUDA files -#if defined(_WIN32) && defined(__CUDACC__) -#define C10_API_ENUM C10_API -#else -#define C10_API_ENUM -#endif - -#endif // C10_MACROS_EXPORT_H_ diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 6b51a39f2a94..87ebc4f422c4 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -1,548 +1 @@ -#ifndef C10_MACROS_MACROS_H_ -#define C10_MACROS_MACROS_H_ -#include - -/* Main entry for c10/macros. - * - * In your code, include c10/macros/Macros.h directly, instead of individual - * files in this folder. - */ - -// For build systems that do not directly depend on CMake and directly build -// from the source directory (such as Buck), one may not have a cmake_macros.h -// file at all. In this case, the build system is responsible for providing -// correct macro definitions corresponding to the cmake_macros.h.in file. -// -// In such scenarios, one should define the macro -// C10_USING_CUSTOM_GENERATED_MACROS -// to inform this header that it does not need to include the cmake_macros.h -// file. - -#ifndef C10_USING_CUSTOM_GENERATED_MACROS -#include -#endif // C10_USING_CUSTOM_GENERATED_MACROS - -#include - -#if defined(__clang__) -#define __ubsan_ignore_float_divide_by_zero__ \ - __attribute__((no_sanitize("float-divide-by-zero"))) -#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) -#define __ubsan_ignore_signed_int_overflow__ \ - __attribute__((no_sanitize("signed-integer-overflow"))) -#define __ubsan_ignore_pointer_overflow__ \ - __attribute__((no_sanitize("pointer-overflow"))) -#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) -#define __ubsan_ignore_float_cast_overflow__ \ - __attribute__((no_sanitize("float-cast-overflow"))) -#else -#define __ubsan_ignore_float_divide_by_zero__ -#define __ubsan_ignore_undefined__ -#define __ubsan_ignore_signed_int_overflow__ -#define __ubsan_ignore_pointer_overflow__ -#define __ubsan_ignore_function__ -#define __ubsan_ignore_float_cast_overflow__ -#endif - -// Detect address sanitizer as some stuff doesn't work with it -#undef C10_ASAN_ENABLED - -// for clang -#if defined(__has_feature) -#if ((__has_feature(address_sanitizer))) -#define C10_ASAN_ENABLED 1 -#endif -#endif - -// for gcc -#if defined(__SANITIZE_ADDRESS__) -#if __SANITIZE_ADDRESS__ -#if !defined(C10_ASAN_ENABLED) -#define C10_ASAN_ENABLED 1 -#endif -#endif -#endif - -#if !defined(C10_ASAN_ENABLED) -#define C10_ASAN_ENABLED 0 -#endif - -// Detect undefined-behavior sanitizer (UBSAN) -#undef C10_UBSAN_ENABLED - -// for clang or gcc >= 14 -// NB: gcc 14 adds support for Clang's __has_feature -// https://gcc.gnu.org/gcc-14/changes.html -// gcc < 14 doesn't have a macro for UBSAN -// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) -// https://github.com/google/sanitizers/issues/765 -#if defined(__has_feature) -#if ((__has_feature(undefined_behavior_sanitizer))) -#define C10_UBSAN_ENABLED 1 -#endif -#endif - -#if !defined(C10_UBSAN_ENABLED) -#define C10_UBSAN_ENABLED 0 -#endif - -// Disable the copy and assignment operator for a class. Note that this will -// disable the usage of the class in std containers. -#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ - classname(const classname&) = delete; \ - classname& operator=(const classname&) = delete - -#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 -#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) - -#define C10_MACRO_EXPAND(args) args - -#define C10_STRINGIZE_IMPL(x) #x -#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) - -/** - * C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with - * str and ends with a unique number. - */ -#ifdef __COUNTER__ -#define C10_UID __COUNTER__ -#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) -#else -#define C10_UID __LINE__ -#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) -#endif - -#ifdef __has_cpp_attribute -#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) -#else -#define C10_HAS_CPP_ATTRIBUTE(x) (0) -#endif - -#ifndef FBCODE_CAFFE2 -/// DEPRECATED: Warn if a type or return value is discarded. -#define C10_NODISCARD [[nodiscard]] - -/// DEPRECATED: Suppress an unused variable. -#define C10_UNUSED [[maybe_unused]] -#endif - -#if !defined(__has_attribute) -#define __has_attribute(x) 0 -#endif - -// Direct port of LLVM_ATTRIBUTE_USED. -#if __has_attribute(used) -#define C10_USED __attribute__((__used__)) -#else -#define C10_USED -#endif - -#define C10_RESTRICT __restrict - -// Simply define the namespace, in case a dependent library want to refer to -// the c10 namespace but not any nontrivial files. -namespace c10 {} -namespace c10::cuda {} -namespace c10::hip {} -namespace c10::xpu {} - -// Since C10 is the core library for caffe2 (and aten), we will simply reroute -// all abstractions defined in c10 to be available in caffe2 as well. -// This is only for backwards compatibility. Please use the symbols from the -// c10 namespace where possible. -namespace caffe2 { -using namespace c10; -} -namespace at { -using namespace c10; -} -namespace at::cuda { -using namespace c10::cuda; -} // namespace at::cuda - -// WARNING!!! THIS IS A GIANT HACK!!! -// This line means you cannot simultaneously include c10/hip -// and c10/cuda and then use them from the at::cuda namespace. -// This is true in practice, because HIPIFY works inplace on -// files in ATen/cuda, so it assumes that c10::hip is available -// from at::cuda. This namespace makes that happen. When -// HIPIFY is no longer out-of-place, we can switch the cuda -// here to hip and everyone is happy. -namespace at::cuda { -using namespace c10::hip; -} // namespace at::cuda - -namespace at::xpu { -using namespace c10::xpu; -} // namespace at::xpu - -// C10_LIKELY/C10_UNLIKELY -// -// These macros provide parentheses, so you can use these macros as: -// -// if C10_LIKELY(some_expr) { -// ... -// } -// -// NB: static_cast to boolean is mandatory in C++, because __builtin_expect -// takes a long argument, which means you may trigger the wrong conversion -// without it. -// -#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) -#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) -#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) -#else -#define C10_LIKELY(expr) (expr) -#define C10_UNLIKELY(expr) (expr) -#endif - -/// C10_NOINLINE - Functions whose declaration is annotated with this will not -/// be inlined. -#ifdef __GNUC__ -#define C10_NOINLINE __attribute__((noinline)) -#elif _MSC_VER -#define C10_NOINLINE __declspec(noinline) -#else -#define C10_NOINLINE -#endif - -#if defined(_MSC_VER) -#define C10_ALWAYS_INLINE __forceinline -#elif __has_attribute(always_inline) || defined(__GNUC__) -#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline -#else -#define C10_ALWAYS_INLINE inline -#endif - -// Unlike C10_ALWAYS_INLINE, C10_ALWAYS_INLINE_ATTRIBUTE can be used -// on a lambda. -#if defined(_MSC_VER) -// MSVC 14.39 is reasonably recent and doesn't like -// [[msvc::forceinline]] on a lambda, so don't try to use it. -#define C10_ALWAYS_INLINE_ATTRIBUTE -#elif __has_attribute(always_inline) || defined(__GNUC__) -#define C10_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) -#else -#define C10_ALWAYS_INLINE_ATTRIBUTE -#endif - -#if defined(_MSC_VER) -#define C10_ATTR_VISIBILITY_HIDDEN -#elif defined(__GNUC__) -#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) -#else -#define C10_ATTR_VISIBILITY_HIDDEN -#endif - -#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN - -#include - -#ifdef __HIPCC__ -// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. -// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. -// See https://github.com/ROCm/hip/issues/441 -#include -#endif - -#if defined(__CUDACC__) || defined(__HIPCC__) -// Designates functions callable from the host (CPU) and the device (GPU) -#define C10_HOST_DEVICE __host__ __device__ -#define C10_DEVICE __device__ -#define C10_HOST __host__ -// constants from -// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) -// The maximum number of threads per multiprocessor is 1024 for Turing -// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and -// 2048 for all other architectures. You'll get warnings if you exceed these -// constants. Hence, the following macros adjust the input values from the user -// to resolve potential warnings. -#if __CUDA_ARCH__ == 750 -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; -#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; -#else -constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; -#endif -// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently -constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; -// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block -// size. 256 is a good number for this fallback and should give good occupancy -// and versatility across all architectures. -constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; -// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it -// turns out that although __launch_bounds__ can take constexpr, it -// can't take a constexpr that has anything to do with templates. -// Currently we use launch_bounds that depend on template arguments in -// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK -// and C10_MIN_BLOCKS_PER_SM are kept as macros. -// Suppose you were planning to write __launch_bounds__(a, b), based on your -// performance tuning on a modern GPU. Instead, you should write -// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), -// which will also properly respect limits on old architectures. -#define C10_MAX_THREADS_PER_BLOCK(val) \ - (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ - : CUDA_THREADS_PER_BLOCK_FALLBACK) -#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ - ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ - ? (blocks_per_sm) \ - : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ - (threads_per_block)))) -// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ -#define C10_LAUNCH_BOUNDS_0 \ - __launch_bounds__( \ - 256, 4) // default launch bounds that should give good occupancy and - // versatility across all architectures. -#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ - __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) -#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ - __launch_bounds__( \ - (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ - (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) -#else -#define C10_HOST_DEVICE -#define C10_HOST -#define C10_DEVICE -#endif - -#if defined(USE_ROCM) -#define C10_HIP_HOST_DEVICE __host__ __device__ -#else -#define C10_HIP_HOST_DEVICE -#endif - -#if defined(USE_ROCM) -// C10_WARP_SIZE is only allowed for device code. -// Host code _must_ use at::cuda::warp_size() -// HIP header used to define warpSize as a constexpr that was either 32 or 64 -// depending on the target device, and then always set it to 64 for host code. -// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we -// set it to something unreasonable to trigger obvious host code errors. -#if defined(__HIP_DEVICE_COMPILE__) -#if defined(__GFX9__) -static constexpr int C10_WARP_SIZE = 64; -#else // __GFX9__ -static constexpr int C10_WARP_SIZE = 32; -#endif // __GFX9__ -#else -static constexpr int C10_WARP_SIZE = 1; -#endif // __HIP_DEVICE_COMPILE__ -#else -#define C10_WARP_SIZE 32 -#endif - -#if defined(_MSC_VER) && _MSC_VER <= 1900 -#define __func__ __FUNCTION__ -#endif - -// CUDA_KERNEL_ASSERT checks the assertion -// even when NDEBUG is defined. This is useful for important assertions in CUDA -// code that would otherwise be suppressed when building Release. -#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) -// Those platforms do not support assert() -#define CUDA_KERNEL_ASSERT(cond) -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) -#define SYCL_KERNEL_ASSERT(cond) -#elif defined(_MSC_VER) -#if defined(NDEBUG) -extern "C" { -C10_IMPORT -#if defined(__SYCL_DEVICE_ONLY__) -extern SYCL_EXTERNAL void _wassert( - const wchar_t* wexpr, - const wchar_t* wfile, - unsigned line); -#else -#if defined(__CUDA_ARCH__) -__host__ __device__ -#endif // __CUDA_ARCH__ - void - _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); -#endif // __SYCL_DEVICE_ONLY__ -} -#endif // NDEBUG -#define CUDA_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -// TODO: This doesn't assert the message because I (chilli) couldn't figure out -// a nice way to convert a char* to a wchar_t* -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - (void)(_wassert( \ - _CRT_WIDE(#cond), \ - _CRT_WIDE(__FILE__), \ - static_cast(__LINE__)), \ - 0); \ - } -#else // __APPLE__, _MSC_VER -#if defined(NDEBUG) -extern "C" { -#if defined(__SYCL_DEVICE_ONLY__) -extern SYCL_EXTERNAL void __assert_fail( - const char* expr, - const char* file, - unsigned int line, - const char* func); -#else // __SYCL_DEVICE_ONLY__ -#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) -// CUDA supports __assert_fail function which are common for both device -// and host side code. -__host__ __device__ -#endif - - // This forward declaration matching the declaration of __assert_fail - // exactly how it is in glibc in case parts of the program are compiled with - // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' - // error. Note: On ROCm - this declaration serves for host side compilation. - void - __assert_fail( - const char* assertion, - const char* file, - unsigned int line, - const char* function) noexcept __attribute__((__noreturn__)); - -#endif // __SYCL_DEVICE_ONLY__ -} -#endif // NDEBUG -// ROCm disables kernel assert by default for performance considerations. -// Though ROCm supports __assert_fail, it uses kernel printf which has -// a non-negligible performance impact even if the assert condition is -// never triggered. We choose to use abort() instead which will still -// terminate the application but without a more useful error message. -#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) -#define CUDA_KERNEL_ASSERT(cond) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if C10_UNLIKELY (!(cond)) { \ - abort(); \ - } -#else -#define CUDA_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - #cond, __FILE__, static_cast(__LINE__), __func__); \ - } -#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - msg, __FILE__, static_cast(__LINE__), __func__); \ - } -#define SYCL_KERNEL_ASSERT(cond) \ - if (C10_UNLIKELY(!(cond))) { \ - __assert_fail( \ - #cond, __FILE__, static_cast(__LINE__), __func__); \ - } -#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM -#endif // __APPLE__ - -#ifdef __APPLE__ -#include -#endif - -#if defined(__ANDROID__) -#define C10_ANDROID 1 -#define C10_MOBILE 1 -#elif ( \ - defined(__APPLE__) && \ - (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) -#define C10_IOS 1 -#define C10_MOBILE 1 -#endif // ANDROID / IOS - -#if defined(C10_MOBILE) && C10_MOBILE -#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline -#else -#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE -#endif - -#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) -#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr -#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr - -#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static constexpr const char field[] = val; -#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) -#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) - -#ifndef HAS_DEMANGLE -#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) -#define HAS_DEMANGLE 0 -#elif defined(__APPLE__) && \ - (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) -#define HAS_DEMANGLE 0 -#else -#define HAS_DEMANGLE 1 -#endif -#endif // HAS_DEMANGLE - -#define _C10_PRAGMA__(string) _Pragma(#string) -#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) - -#ifdef __clang__ -#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") -#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") -#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ - _C10_PRAGMA_(clang diagnostic ignored flag) -#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) -#else -#define C10_CLANG_DIAGNOSTIC_PUSH() -#define C10_CLANG_DIAGNOSTIC_POP() -#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) -#define C10_CLANG_HAS_WARNING(flag) 0 -#endif - -#ifdef __clang__ - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ - _C10_PRAGMA_(clang diagnostic push) \ - _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ - _C10_PRAGMA_(clang diagnostic ignored warning) - -#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) - -#elif __GNUC__ - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ - _C10_PRAGMA_(GCC diagnostic push) \ - _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ - _C10_PRAGMA_(GCC diagnostic ignored warning) - -#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop) - -#else - -#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) -#define C10_DIAGNOSTIC_POP() - -#endif - -// This macro is used to find older C++ compilers -// that don't support move optimization for return values. - -#if (defined(__GNUC__) && __GNUC__ < 13) || \ - (defined(__clang_major__) && __clang_major__ < 13) -#define C10_RETURN_MOVE_IF_OLD_COMPILER 1 -#else -#define C10_RETURN_MOVE_IF_OLD_COMPILER 0 -#endif - -#endif // C10_MACROS_MACROS_H_ +#include diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index 129b2b1e0570..d5809d36687d 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -1,13 +1,13 @@ def define_targets(rules): rules.cc_library( name = "macros", - srcs = [":cmake_macros_h"], hdrs = [ "Macros.h", # Despite the documentation in Macros.h, Export.h is included # directly by many downstream files. Thus, we declare it as a # public header in this file. "Export.h", + "cmake_macros.h", ], linkstatic = True, local_defines = ["C10_BUILD_MAIN_LIB"], @@ -17,22 +17,6 @@ def define_targets(rules): ], ) - rules.cmake_configure_file( - name = "cmake_macros_h", - src = "cmake_macros.h.in", - out = "cmake_macros.h", - definitions = [ - "C10_BUILD_SHARED_LIBS", - "C10_USE_MSVC_STATIC_RUNTIME", - ] + rules.select({ - "//c10:using_gflags": ["C10_USE_GFLAGS"], - "//conditions:default": [], - }) + rules.select({ - "//c10:using_glog": ["C10_USE_GLOG"], - "//conditions:default": [], - }), - ) - rules.filegroup( name = "headers", srcs = rules.glob( diff --git a/c10/macros/cmake_macros.h b/c10/macros/cmake_macros.h new file mode 100644 index 000000000000..4358f6906c97 --- /dev/null +++ b/c10/macros/cmake_macros.h @@ -0,0 +1,5 @@ +// This file exists for backwards compatibility and has been moved to +// torch/headeronly/macros/cmake_macros.h.in. No end user library should be +// including this file directly anyway (cuz they should be including +// Macros.h instead). +#include diff --git a/c10/ovrsource_defs.bzl b/c10/ovrsource_defs.bzl index 4abf8b0014de..aafe5a4de8c4 100644 --- a/c10/ovrsource_defs.bzl +++ b/c10/ovrsource_defs.bzl @@ -73,8 +73,7 @@ def define_c10_ovrsource(name, is_mobile): ], }), exported_deps = [ - "//xplat/caffe2/torch/headeronly:torch_headeronly", - ":ovrsource_c10_cmake_macros.h", + "//xplat/caffe2/torch/headeronly:torch_headeronly_ovrsource", "//arvr/third-party/gflags:gflags", "//third-party/cpuinfo:cpuinfo", "//third-party/fmt:fmt", @@ -83,55 +82,6 @@ def define_c10_ovrsource(name, is_mobile): ) def define_ovrsource_targets(): - common_c10_cmake_defines = [ - ("#cmakedefine C10_BUILD_SHARED_LIBS", ""), - ("#cmakedefine C10_USE_NUMA", ""), - ("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""), - ("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""), - ] - - mobile_c10_cmake_defines = [ - ("#cmakedefine C10_USE_GLOG", ""), - ("#cmakedefine C10_USE_GFLAGS", ""), - ] - - non_mobile_c10_cmake_defines = [ - ("#cmakedefine C10_USE_GLOG", "#define C10_USE_GLOG 1"), - ("#cmakedefine C10_USE_GFLAGS", "#define C10_USE_GFLAGS 1"), - ] - - gen_cmake_header( - src = "macros/cmake_macros.h.in", - defines = common_c10_cmake_defines + mobile_c10_cmake_defines, - header = "c10/macros/cmake_macros.h", - prefix = "ovrsource_c10_mobile_", - ) - - gen_cmake_header( - src = "macros/cmake_macros.h.in", - defines = common_c10_cmake_defines + non_mobile_c10_cmake_defines, - header = "c10/macros/cmake_macros.h", - prefix = "ovrsource_c10_non_mobile_", - ) - - oxx_static_library( - name = "ovrsource_c10_cmake_macros.h", - compatible_with = [ - "ovr_config//os:android", - "ovr_config//os:iphoneos", - "ovr_config//os:linux", - "ovr_config//os:macos", - "ovr_config//os:windows", - ], - deps = select({ - "ovr_config//os:android": [":ovrsource_c10_mobile_cmake_macros.h"], - "ovr_config//os:iphoneos": [":ovrsource_c10_mobile_cmake_macros.h"], - "ovr_config//os:linux": [":ovrsource_c10_non_mobile_cmake_macros.h"], - "ovr_config//os:macos": [":ovrsource_c10_non_mobile_cmake_macros.h"], - "ovr_config//os:windows": [":ovrsource_c10_non_mobile_cmake_macros.h"], - }), - ) - c10_cuda_macros = gen_cmake_header( src = "cuda/impl/cuda_cmake_macros.h.in", defines = [ diff --git a/c10/test/core/AllocatorConfig_test.cpp b/c10/test/core/AllocatorConfig_test.cpp deleted file mode 100644 index c2c0e6261d7b..000000000000 --- a/c10/test/core/AllocatorConfig_test.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include - -#include - -using namespace c10::CachingAllocator; -constexpr size_t kMB = 1024 * 1024ul; - -struct ExtendedAllocatorConfig { - static ExtendedAllocatorConfig& instance() { - static ExtendedAllocatorConfig instance; - return instance; - } - - // Returns the device-specific option value in bytes. - static size_t device_specific_option() { - return instance().device_specific_option_; - } - - static const std::unordered_set& getKeys() { - return instance().keys_; - } - - void parseArgs(const std::string& env) { - // Parse device-specific options from the environment variable - ConfigTokenizer tokenizer(env); - for (size_t i = 0; i < tokenizer.size(); i++) { - const auto& key = tokenizer[i]; - if (key == "device_specific_option_mb") { - tokenizer.checkToken(++i, ":"); - device_specific_option_ = tokenizer.toSizeT(++i) * kMB; - } else { - i = tokenizer.skipKey(i); - } - - if (i + 1 < tokenizer.size()) { - tokenizer.checkToken(++i, ","); - } - } - } - - private: - // Device-specific option, e.g., memory limit for a specific device. - std::atomic device_specific_option_{0}; - std::unordered_set keys_{"device_specific_option_mb"}; -}; - -REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig) - -TEST(AllocatorConfigTest, allocator_config_test) { - std::string env = - "max_split_size_mb:40," - "max_non_split_rounding_mb:30," - "garbage_collection_threshold:0.5," - "roundup_power2_divisions:[64:8,128:2,256:4,512:2,1024:4,>:1]," - "expandable_segments:True," - "pinned_use_background_threads:True," - "device_specific_option_mb:64"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 40 * kMB); - EXPECT_EQ( - AcceleratorAllocatorConfig::max_non_split_rounding_size(), 30 * kMB); - EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.5); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(32 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 2); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 2); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 4); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(8192 * kMB), 1); - EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), true); - EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), true); - EXPECT_EQ(ExtendedAllocatorConfig::device_specific_option(), 64 * kMB); - - env = - "max_split_size_mb:20," - "max_non_split_rounding_mb:40," - "garbage_collection_threshold:0.8"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::max_split_size(), 20 * kMB); - EXPECT_EQ( - AcceleratorAllocatorConfig::max_non_split_rounding_size(), 40 * kMB); - EXPECT_EQ(AcceleratorAllocatorConfig::garbage_collection_threshold(), 0.8); - - // roundup_power2_divisions knob array syntax - env = "roundup_power2_divisions:[128:8,256:16,512:1,2048:8,>:2]"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 8); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(128 * kMB), 8); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 16); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(512 * kMB), 1); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(1024 * kMB), 0); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 8); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(4096 * kMB), 2); - - // roundup_power2_divisions single value syntax for backward compatibility - env = "roundup_power2_divisions:4"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(64 * kMB), 4); - EXPECT_EQ(AcceleratorAllocatorConfig::roundup_power2_divisions(256 * kMB), 4); - EXPECT_EQ( - AcceleratorAllocatorConfig::roundup_power2_divisions(2048 * kMB), 4); - - env = "expandable_segments:False,"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::use_expandable_segments(), false); - - env = "pinned_use_background_threads:False"; - c10::CachingAllocator::setAllocatorSettings(env); - EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env); - EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false); - - env = "foo:123,bar:456"; - ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error); -} diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 3ff3396f5f1b..545cef535138 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -267,6 +267,13 @@ class C10_API NotImplementedError : public Error { using Error::Error; }; +// Used in ATen for buffer-related errors, e.g. trying to create a DLPack of +// an unsupported device. These turn into BufferError when they cross to +// Python. +class C10_API BufferError : public Error { + using Error::Error; +}; + // Used in ATen for non finite indices. These turn into // ExitException when they cross to Python. class C10_API EnforceFiniteError : public Error { @@ -365,26 +372,7 @@ C10_API std::string GetExceptionString(const std::exception& e); // https://stackoverflow.com/questions/5134523/msvc-doesnt-expand-va-args-correctly #define C10_EXPAND_MSVC_WORKAROUND(x) x -// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases -// where the unlikely expression may be a constant, use this macro to ensure -// return statement analysis keeps working (at the cost of not getting the -// likely/unlikely annotation on nvcc). -// https://github.com/pytorch/pytorch/issues/21418 -// -// Currently, this is only used in the error reporting macros below. If you -// want to use it more generally, move me to Macros.h -// -// TODO: Brian Vaughan observed that we might be able to get this to work on -// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs -// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY -// in nvcc is causing us perf problems, this is not yet implemented, but this -// might be an interesting piece of C++ code for an intrepid bootcamper to -// write. -#if defined(__CUDACC__) -#define C10_UNLIKELY_OR_CONST(e) e -#else -#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) -#endif +#include // ---------------------------------------------------------------------------- // Error reporting macros @@ -654,6 +642,10 @@ namespace c10::detail { #define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__) +// Like TORCH_CHECK, but raises BufferError instead of Errors. +#define TORCH_CHECK_BUFFER(cond, ...) \ + TORCH_CHECK_WITH_MSG(BufferError, cond, "TYPE", __VA_ARGS__) + #define TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(cond, ...) \ TORCH_CHECK_WITH_MSG( \ ErrorAlwaysShowCppStacktrace, cond, "TYPE", ##__VA_ARGS__) diff --git a/c10/util/TypeSafeSignMath.h b/c10/util/TypeSafeSignMath.h index 2853ff48d183..58c050678302 100644 --- a/c10/util/TypeSafeSignMath.h +++ b/c10/util/TypeSafeSignMath.h @@ -79,7 +79,7 @@ template inline constexpr bool greater_than_max(const T& x) { constexpr bool can_overflow = std::numeric_limits::digits > std::numeric_limits::digits; - return can_overflow && x > std::numeric_limits::max(); + return can_overflow && x > (std::numeric_limits::max)(); } #ifdef __GNUC__ diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index 193740cb10db..c87c2e3293e5 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index afae32d92a4b..543b48f08113 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -21,6 +20,8 @@ constexpr size_t kMinBlockSize = 512; constexpr size_t kSmallSize = 1048576; // "small" allocations are packed in 2 MiB blocks constexpr size_t kSmallBuffer = 2097152; +// "large" allocations may be packed in 20 MiB blocks +constexpr size_t kLargeBuffer = 20971520; // allocations between 1 and 10 MiB may use kLargeBuffer constexpr size_t kMinLargeAlloc = 10485760; // round up large allocations to 2 MiB diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 6f6feac4b1ce..1edcb36e94f9 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1001,7 +1001,7 @@ elseif(USE_CUDA) # 3. Let CMake find it in the default system paths, e.g. /usr/local. find_library(NVSHMEM_HOST_LIB # In pip install case, the lib suffix is `.so.3` instead of `.so` - NAMES nvshmem_host nvshmem_host.so.3 + NAMES nvshmem_host libnvshmem_host.so.3 NAMES_PER_DIR HINTS $ENV{NVSHMEM_HOME} ${NVSHMEM_PY_DIR} PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64 DOC "The location of NVSHMEM host library.") @@ -1027,24 +1027,24 @@ elseif(USE_CUDA) # Linking with nvshmem requires the source binary to be built with -rdc # which is not viable for libtorch_cuda. So we isolate the linking of - # nvshmem in nvshmem_extension. - add_library(nvshmem_extension SHARED + # nvshmem in torch_nvshmem. + add_library(torch_nvshmem SHARED "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu" "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp" ) - set_target_properties(nvshmem_extension PROPERTIES CUDA_SEPARABLE_COMPILATION ON) - target_compile_options(nvshmem_extension PRIVATE $<$:-rdc=true>) - target_compile_options(nvshmem_extension PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") - target_link_libraries(nvshmem_extension PRIVATE + set_target_properties(torch_nvshmem PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + target_compile_options(torch_nvshmem PRIVATE $<$:-rdc=true>) + target_compile_options(torch_nvshmem PRIVATE "-U__CUDA_NO_HALF_OPERATORS__") + target_link_libraries(torch_nvshmem PRIVATE ${NVSHMEM_HOST_LIB} ${NVSHMEM_DEVICE_LIB} ) target_compile_definitions(torch_cuda PUBLIC USE_NVSHMEM) - target_compile_definitions(nvshmem_extension PUBLIC USE_NVSHMEM) - target_link_libraries(torch_cuda PRIVATE nvshmem_extension) - install(TARGETS nvshmem_extension EXPORT Caffe2Targets DESTINATION lib) + target_compile_definitions(torch_nvshmem PUBLIC USE_NVSHMEM) + target_link_libraries(torch_cuda PRIVATE torch_nvshmem) + install(TARGETS torch_nvshmem EXPORT Caffe2Targets DESTINATION "${TORCH_INSTALL_LIB_DIR}") else() message(STATUS "NVSHMEM not found, not building with NVSHMEM support.") endif() diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 8004b0f400a8..8b380d24f6c8 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -24,7 +24,7 @@ if(NOT __AOTRITON_INCLUDED) set(__AOTRITON_SHA256_LIST "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3 "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4 - "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0 + "1e9b3dddf0c7fc07131c6f0f5266129e83ce2331f459fa2be8c63f4ae91b0f5b" # rocm7.0 ) set(__AOTRITON_Z "gz") diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index a80365f353df..3c2ec74f14d1 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -70,7 +70,6 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_CPP_CODE_COVERAGE : ${USE_CPP_CODE_COVERAGE}") message(STATUS " USE_CUDA : ${USE_CUDA}") if(${USE_CUDA}) - message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}") message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}") message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index cae0ca62f236..132f9670ff34 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -93,19 +93,16 @@ if(HIP_FOUND) # hip (lower-case) package. Both are probed above and will be in # ROCM_INCLUDE_DIRS if available. find_file(ROCM_VERSION_HEADER_PATH - NAMES rocm-core/rocm_version.h + NAMES rocm-core/rocm_version.h hip/hip_version.h NO_DEFAULT_PATH PATHS ${ROCM_INCLUDE_DIRS} ) - set(ROCM_LIB_NAME "ROCM") - if(NOT ROCM_VERSION_HEADER_PATH) - find_file(ROCM_VERSION_HEADER_PATH - NAMES hip/hip_version.h - NO_DEFAULT_PATH - PATHS ${ROCM_INCLUDE_DIRS} - ) + if(ROCM_VERSION_HEADER_PATH MATCHES "rocm-core/rocm_version.h$") + set(ROCM_LIB_NAME "ROCM") + else() set(ROCM_LIB_NAME "HIP") endif() + if(NOT ROCM_VERSION_HEADER_PATH) message(FATAL_ERROR "Could not find hip/hip_version.h or rocm-core/rocm_version.h in ${ROCM_INCLUDE_DIRS}") endif() diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 0f34603cf023..032db8a8ab5c 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -394,7 +394,7 @@ function(torch_compile_options libname) list(APPEND private_compile_options -Wredundant-move) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - list(APPEND private_compile_options -Wextra-semi -Wno-error=extra-semi -Wmove) + list(APPEND private_compile_options -Wextra-semi -Wmove) else() list(APPEND private_compile_options # Considered to be flaky. See the discussion at diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index be083cb93af1..b39e31d0ade8 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -11,6 +11,7 @@ set(XPU_HOST_CXX_FLAGS) find_package(SYCLToolkit REQUIRED) if(NOT SYCL_FOUND) set(PYTORCH_FOUND_XPU FALSE) + # Exit early to avoid populating XPU_HOST_CXX_FLAGS. return() endif() set(PYTORCH_FOUND_XPU TRUE) @@ -36,6 +37,8 @@ torch_xpu_get_arch_list(XPU_ARCH_FLAGS) # propagate to torch-xpu-ops set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) +# Ensure USE_XPU is enabled. +string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU") string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") if(DEFINED ENV{XPU_ENABLE_KINETO}) diff --git a/docs/source/backends.md b/docs/source/backends.md index 41869ba9b77b..6b8cc8bd7072 100644 --- a/docs/source/backends.md +++ b/docs/source/backends.md @@ -54,7 +54,7 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls whether TensorFloat-32 tensor cores may be used in matrix - multiplications on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. + multiplications on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. ``` ```{eval-rst} @@ -193,7 +193,7 @@ These backends include: .. attribute:: allow_tf32 A :class:`bool` that controls where TensorFloat-32 tensor cores may be used in cuDNN - convolutions on Ampere or newer GPUs. See :ref:`tf32_on_ampere`. + convolutions on Ampere or newer GPUs. allow_tf32 is going to be deprecated. See :ref:`tf32_on_ampere`. ``` ```{eval-rst} diff --git a/docs/source/conf.py b/docs/source/conf.py index acb2b088af72..34d8e9876b17 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,7 +62,7 @@ "sphinxcontrib.katex", "sphinx_copybutton", "sphinx_design", - "myst_parser", + "myst_nb", "sphinx.ext.linkcode", "sphinxcontrib.mermaid", "sphinx_sitemap", @@ -82,6 +82,10 @@ ] sitemap_url_scheme = "{link}" +html_additional_pages = { + "404": "404.html", +} + # build the templated autosummary files autosummary_generate = True numpydoc_show_class_members = False @@ -1837,31 +1841,9 @@ "check_export_model_diff", "verify", "verify_aten_graph", - # torch.optim.adadelta - "adadelta", - # torch.optim.adagrad - "adagrad", - # torch.optim.adam - "adam", - # torch.optim.adamax - "adamax", - # torch.optim.adamw - "adamw", - # torch.optim.asgd - "asgd", - # torch.optim.nadam - "nadam", # torch.optim.optimizer "register_optimizer_step_post_hook", "register_optimizer_step_pre_hook", - # torch.optim.radam - "radam", - # torch.optim.rmsprop - "rmsprop", - # torch.optim.rprop - "rprop", - # torch.optim.sgd - "sgd", # torch.optim.swa_utils "get_ema_avg_fn", "get_ema_multi_avg_fn", @@ -3109,12 +3091,6 @@ # torch.onnx.verification "OnnxBackend", "OnnxTestCaseRepro", - # torch.optim.adadelta - "Adadelta", - # torch.optim.adagrad - "Adagrad", - # torch.optim.adam - "Adam", # torch.optim.adamax "Adamax", # torch.optim.adamw @@ -3140,23 +3116,8 @@ "ReduceLROnPlateau", "SequentialLR", "StepLR", - # torch.optim.nadam - "NAdam", # torch.optim.optimizer "Optimizer", - # torch.optim.radam - "RAdam", - # torch.optim.rmsprop - "RMSprop", - # torch.optim.rprop - "Rprop", - # torch.optim.sgd - "SGD", - # torch.optim.sparse_adam - "SparseAdam", - # torch.optim.swa_utils - "AveragedModel", - "SWALR", # torch.overrides "BaseTorchFunctionMode", "TorchFunctionMode", diff --git a/docs/source/cpp_index.rst b/docs/source/cpp_index.rst index 23302286f0e3..37571b9c60bc 100644 --- a/docs/source/cpp_index.rst +++ b/docs/source/cpp_index.rst @@ -7,20 +7,6 @@ C++ PyTorch provides several features for working with C++, and it’s best to choose from them based on your needs. At a high level, the following support is available: -TorchScript C++ API --------------------- -`TorchScript `__ allows PyTorch models defined in Python to be serialized and then loaded and run in C++ capturing the model code via compilation or tracing its execution. You can learn more in the `Loading a TorchScript Model in C++ tutorial `__. This means you can define your models in Python as much as possible, but subsequently export them via TorchScript for doing no-Python execution in production or embedded environments. The TorchScript C++ API is used to interact with these models and the TorchScript execution engine, including: - -* Loading serialized TorchScript models saved from Python -* Doing simple model modifications if needed (e.g. pulling out submodules) -* Constructing the input and doing preprocessing using C++ Tensor API - -Extending PyTorch and TorchScript with C++ Extensions ------------------------------------------------------- -TorchScript can be augmented with user-supplied code through custom operators and custom classes. -Once registered with TorchScript, these operators and classes can be invoked in TorchScript code run from -Python or from C++ as part of a serialized TorchScript model. The `Extending TorchScript with Custom C++ Operators `__ tutorial walks through interfacing TorchScript with OpenCV. In addition to wrapping a function call with a custom operator, C++ classes and structs can be bound into TorchScript through a pybind11-like interface which is explained in the `Extending TorchScript with Custom C++ Classes `__ tutorial. - Tensor and Autograd in C++ --------------------------- Most of the tensor and autograd operations in PyTorch Python API are also available in the C++ API. These include: @@ -31,9 +17,7 @@ Most of the tensor and autograd operations in PyTorch Python API are also availa Authoring Models in C++ ------------------------ -The "author in TorchScript, infer in C++" workflow requires model authoring to be done in TorchScript. -However, there might be cases where the model has to be authored in C++ (e.g. in workflows where a Python -component is undesirable). To serve such use cases, we provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. +We provide the full capability of authoring and training a neural net model purely in C++, with familiar components such as ``torch::nn`` / ``torch::nn::functional`` / ``torch::optim`` that closely resemble the Python API. * For an overview of the PyTorch C++ model authoring and training API, please see: https://pytorch.org/cppdocs/frontend.html * For a detailed tutorial on how to use the API, please see: https://pytorch.org/tutorials/advanced/cpp_frontend.html diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 95820f8244c5..9762e79c7ea3 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -20,39 +20,41 @@ for a brief introduction to all features related to distributed training. ## Backends -`torch.distributed` supports three built-in backends, each with +`torch.distributed` supports four built-in backends, each with different capabilities. The table below shows which functions are available -for use with CPU / CUDA tensors. +for use with a CPU or GPU for each backend. For NCCL, GPU refers to CUDA GPU +while for XCCL to XPU GPU. + MPI supports CUDA only if the implementation used to build PyTorch supports it. ```{eval-rst} -+----------------+-----------+-----------+-----------+ -| Backend | ``gloo`` | ``mpi`` | ``nccl`` | -+----------------+-----+-----+-----+-----+-----+-----+ -| Device | CPU | GPU | CPU | GPU | CPU | GPU | -+================+=====+=====+=====+=====+=====+=====+ -| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ -| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | -+----------------+-----+-----+-----+-----+-----+-----+ ++----------------+-----------+-----------+-----------+-----------+ +| Backend | ``gloo`` | ``mpi`` | ``nccl`` | ``xccl`` | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| Device | CPU | GPU | CPU | GPU | CPU | GPU | CPU | GPU | ++================+=====+=====+=====+=====+=====+=====+=====+=====+ +| send | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| recv | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| broadcast | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| reduce | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| gather | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| scatter | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| reduce_scatter | ✓ | ✓ | ✘ | ✘ | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| all_to_all | ✓ | ✓ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ +| barrier | ✓ | ✘ | ✓ | ? | ✘ | ✓ | ✘ | ✓ | ++----------------+-----+-----+-----+-----+-----+-----+-----+-----+ ``` ### Backends that come with PyTorch @@ -81,8 +83,9 @@ In the past, we were often asked: "which backend should I use?". - Rule of thumb - - Use the NCCL backend for distributed **GPU** training - - Use the Gloo backend for distributed **CPU** training. + - Use the NCCL backend for distributed training with CUDA **GPU**. + - Use the XCCL backend for distributed training with XPU **GPU**. + - Use the Gloo backend for distributed training with **CPU**. - GPU hosts with InfiniBand interconnect diff --git a/docs/source/export.md b/docs/source/export.md index 9d57614a14ad..fcebcc6d4962 100644 --- a/docs/source/export.md +++ b/docs/source/export.md @@ -1,12 +1,17 @@ +--- +file_format: mystnb +kernelspec: + name: python3 +mystnb: + execution_timeout: 30 + execution_show_tb: True + merge_streams: True +--- + (torch.export)= # torch.export -:::{warning} -This feature is a prototype under active development and there WILL BE -BREAKING CHANGES in the future. -::: - ## Overview {func}`torch.export.export` takes a {class}`torch.nn.Module` and produces a traced graph @@ -14,9 +19,9 @@ representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different outputs or serialized. -```python +```{code-cell} import torch -from torch.export import export +from torch.export import export, ExportedProgram class Mod(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -26,53 +31,10 @@ class Mod(torch.nn.Module): example_args = (torch.randn(10, 10), torch.randn(10, 10)) -exported_program: torch.export.ExportedProgram = export( - Mod(), args=example_args -) +exported_program: ExportedProgram = export(Mod(), args=example_args) print(exported_program) ``` -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"): - # code: a = torch.sin(x) - sin: "f32[10, 10]" = torch.ops.aten.sin.default(x) - - # code: b = torch.cos(y) - cos: "f32[10, 10]" = torch.ops.aten.cos.default(y) - - # code: return a + b - add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos) - return (add,) - - Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='y'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='add'), - target=None - ) - ] - ) - Range constraints: {} -``` - `torch.export` produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found {ref}`here `. @@ -130,10 +92,10 @@ level). Note that users can still use {func}`torch.fx.symbolic_trace` as a preprocessing step before `torch.export`. Compared to {func}`torch.jit.script`, `torch.export` does not capture Python -control flow or data structures, but it supports more Python language -features due to its comprehensive coverage over Python bytecodes. -The resulting graphs are simpler and only have straight line control -flow, except for explicit control flow operators. +control flow or data structures, unless using explicit {ref}`control flow operators `, +but it supports more Python language features due to its comprehensive coverage +over Python bytecodes. The resulting graphs are simpler and only have straight +line control flow, except for explicit control flow operators. Compared to {func}`torch.jit.trace`, `torch.export` is sound: it can trace code that performs integer computation on sizes and records @@ -142,16 +104,14 @@ trace is valid for other inputs. ## Exporting a PyTorch Model -### An Example - The main entrypoint is through {func}`torch.export.export`, which takes a -callable ({class}`torch.nn.Module`, function, or method) and sample inputs, and +{class}`torch.nn.Module` and sample inputs, and captures the computation graph into an {class}`torch.export.ExportedProgram`. An example: -```python +```{code-cell} import torch -from torch.export import export +from torch.export import export, ExportedProgram # Simple module for demonstration class M(torch.nn.Module): @@ -171,64 +131,13 @@ class M(torch.nn.Module): example_args = (torch.randn(1, 3, 256, 256),) example_kwargs = {"constant": torch.ones(1, 16, 256, 256)} -exported_program: torch.export.ExportedProgram = export( +exported_program: ExportedProgram = export( M(), args=example_args, kwargs=example_kwargs ) print(exported_program) -``` -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"): - # code: a = self.conv(x) - conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]) - - # code: a.add_(constant) - add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant) - - # code: return self.maxpool(self.relu(a)) - relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_) - max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]) - return (max_pool2d,) - -Graph signature: - ExportGraphSignature( - input_specs=[ - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_weight'), - target='conv.weight', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='p_conv_bias'), - target='conv.bias', - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='x'), - target=None, - persistent=None - ), - InputSpec( - kind=, - arg=TensorArgument(name='constant'), - target=None, - persistent=None - ) - ], - output_specs=[ - OutputSpec( - kind=, - arg=TensorArgument(name='max_pool2d'), - target=None - ) - ] - ) -Range constraints: {} +# To run the exported program, we can use the `module()` method +print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256))) ``` Inspecting the `ExportedProgram`, we can note the following: @@ -236,191 +145,66 @@ Inspecting the `ExportedProgram`, we can note the following: - The {class}`torch.fx.Graph` contains the computation graph of the original program, along with records of the original code for easy debugging. - The graph contains only `torch.ops.aten` operators found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) - and custom operators, and is fully functional, without any inplace operators - such as `torch.add_`. + and custom operators. - The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no `get_attr` nodes in the graph, which previously existed in the result of {func}`torch.fx.symbolic_trace`. - The {class}`torch.export.ExportGraphSignature` models the input and output signature, along with specifying which inputs are parameters. - The resulting shape and dtype of tensors produced by each node in the graph is - noted. For example, the `convolution` node will result in a tensor of dtype + noted. For example, the `conv2d` node will result in a tensor of dtype `torch.float32` and shape (1, 16, 256, 256). -(non-strict-export)= - -### Non-Strict Export - -In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. +## Expressing Dynamism -In *non-strict mode*, we trace through the program using the Python interpreter. -Your code will execute exactly as it would in eager mode; the only difference is -that all Tensor objects will be replaced by ProxyTensors, which will record all -their operations into a graph. - -In *strict* mode, which is currently the default, we first trace through the -program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not -actually execute your Python code. Instead, it symbolically analyzes it and -builds a graph based on the results. This analysis allows torch.export to -provide stronger guarantees about safety, but not all Python code is supported. +By default `torch.export` will trace the program assuming all input shapes are +**static**, and specializing the exported program to those dimensions. One +consequence of this is that at runtime, the program won’t work on inputs with +different shapes, even if they’re valid in eager mode. -An example of a case where one might want to use non-strict mode is if you run -into a unsupported TorchDynamo feature that might not be easily solved, and you -know the python code is not exactly needed for computation. For example: +An example: -```python -import contextlib +```{code-cell} import torch - -class ContextManager(): - def __init__(self): - self.count = 0 - def __enter__(self): - self.count += 1 - def __exit__(self, exc_type, exc_value, traceback): - self.count -= 1 +import traceback as tb class M(torch.nn.Module): - def forward(self, x): - with ContextManager(): - return x.sin() + x.cos() - -export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully -export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager -``` - -In this example, the first call using non-strict mode (through the -`strict=False` flag) traces successfully whereas the second call using strict -mode (default) results with a failure, where TorchDynamo is unable to support -context managers. One option is to rewrite the code (see {ref}`Limitations of torch.export `), -but seeing as the context manager does not affect the tensor -computations in the model, we can go with the non-strict mode's result. - -(training-export)= - -### Export for Training and Inference - -In PyTorch 2.5, we introduced a new API called {func}`export_for_training`. -It's still going through hardening, so if you run into any issues, please file -them to Github with the "oncall: export" tag. - -In this API, we produce the most generic IR that contains all ATen operators -(including both functional and non-functional) which can be used to train in -eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization -and will soon be the default IR of torch.export.export. To read further about -the motivation behind this change, please refer to - - -When this API is combined with {func}`run_decompositions()`, you should be able to get inference IR with -any desired decomposition behavior. - -To show some examples: - -```python -class ConvBatchnorm(torch.nn.Module): - def __init__(self) -> None: + def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(1, 3, 1, 1) - self.bn = torch.nn.BatchNorm2d(3) - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - return (x,) - -mod = ConvBatchnorm() -inp = torch.randn(1, 1, 3, 3) - -ep_for_training = torch.export.export_for_training(mod, (inp,)) -print(ep_for_training) -``` - -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1) - batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True) - return (batch_norm,) -``` - -From the above output, you can see that {func}`export_for_training` produces pretty much the same ExportedProgram -as {func}`export` except for the operators in the graph. You can see that we captured batch_norm in the most general -form. This op is non-functional and will be lowered to different ops when running inference. - -You can also go from this IR to an inference IR via {func}`run_decompositions` with arbitrary customizations. - -```python -# Lower to core aten inference IR, but keep conv2d -decomp_table = torch.export.default_decompositions() -del decomp_table[torch.ops.aten.conv2d.default] -ep_for_inference = ep_for_training.run_decompositions(decomp_table) - -print(ep_for_inference) -``` -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4] - return (getitem_3, getitem_4, add, getitem) -``` - -Here you can see that we kept `conv2d` op in the IR while decomposing the rest. Now the IR is a functional IR -containing core aten operators except for `conv2d`. - -You can do even more customization by directly registering your chosen decomposition behaviors. - -You can do even more customizations by directly registering custom decomp behaviour - -```python -# Lower to core aten inference IR, but customize conv2d -decomp_table = torch.export.default_decompositions() + self.branch1 = torch.nn.Sequential( + torch.nn.Linear(64, 32), torch.nn.ReLU() + ) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer = torch.ones(32) -def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): - return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer, out2) -decomp_table[torch.ops.aten.conv2d.default] = my_awesome_conv2d_function -ep_for_inference = ep_for_training.run_decompositions(decomp_table) +example_args = (torch.randn(32, 64), torch.randn(32, 128)) -print(ep_for_inference) -``` +ep = torch.export.export(M(), example_args) +print(ep) -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"): - convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1) - mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2) - add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1) - _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05) - getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] - getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] - getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; - return (getitem_3, getitem_4, add, getitem) +example_args2 = (torch.randn(64, 64), torch.randn(64, 128)) +try: + ep.module()(*example_args2) # fails +except Exception: + tb.print_exc() ``` -### Expressing Dynamism -By default `torch.export` will trace the program assuming all input shapes are -**static**, and specializing the exported program to those dimensions. However, -some dimensions, such as a batch dimension, can be dynamic and vary from run to -run. Such dimensions must be specified by using the -{func}`torch.export.Dim` API to create them and by passing them into -{func}`torch.export.export` through the `dynamic_shapes` argument. An example: +However, some dimensions, such as a batch dimension, can be dynamic and vary +from run to run. Such dimensions must be specified by using the +{func}`torch.export.Dim()` API to create them and by passing them into +{func}`torch.export.export()` through the `dynamic_shapes` argument. -```python +```{code-cell} import torch -from torch.export import Dim, export class M(torch.nn.Module): def __init__(self): @@ -442,42 +226,25 @@ class M(torch.nn.Module): example_args = (torch.randn(32, 64), torch.randn(32, 128)) # Create a dynamic batch size -batch = Dim("batch") +batch = torch.export.Dim("batch") # Specify that the first dimension of each input is that batch size dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} -exported_program: torch.export.ExportedProgram = export( +ep = torch.export.export( M(), args=example_args, dynamic_shapes=dynamic_shapes ) -print(exported_program) -``` - -```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"): - - # code: out1 = self.branch1(x1) - linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias) - relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear) +print(ep) - # code: out2 = self.branch2(x2) - linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias) - relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1) - - # code: return (out1 + self.buffer, out2) - add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer) - return (add, relu_1) - -Range constraints: {s0: VR[0, int_oo]} +example_args2 = (torch.randn(64, 64), torch.randn(64, 128)) +ep.module()(*example_args2) # success ``` Some additional things to note: - Through the {func}`torch.export.Dim` API and the `dynamic_shapes` argument, we specified the first dimension of each input to be dynamic. Looking at the inputs `x1` and - `x2`, they have a symbolic shape of (s0, 64) and (s0, 128), instead of - the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. + `x2`, they have a symbolic shape of `(s0, 64)` and `(s0, 128)`, instead of + the `(32, 64)` and `(32, 128)` shaped tensors that we passed in as example inputs. `s0` is a symbol representing that this dimension can be a range of values. - `exported_program.range_constraints` describes the ranges of each symbol @@ -488,436 +255,407 @@ Some additional things to note: [The 0/1 Specialization Problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk) for an in-depth discussion of this topic. -We can also specify more expressive relationships between input shapes, such as -where a pair of shapes might differ by one, a shape might be double of -another, or a shape is even. An example: -```python -class M(torch.nn.Module): - def forward(self, x, y): - return x + y[1:] +In the example, we used `Dim("batch")` to create a dynamic dimension. This is +the most explicit way to specify dynamism. We can also use `Dim.DYNAMIC` and +`Dim.AUTO` to specify dynamism. We will go over both methods in the next section. -x, y = torch.randn(5), torch.randn(6) -dimx = torch.export.Dim("dimx", min=3, max=6) -dimy = dimx + 1 +### Named Dims -exported_program = torch.export.export( - M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}), -) -print(exported_program) -``` +For every dimension specified with `Dim("name")`, we will allocate a symbolic +shape. Specifying a `Dim` with the same name will result in the same symbol +to be generated. This allows users to specify what symbols are allocated for +each input dimension. ```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"): - # code: return x + y[1:] - slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807) - add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1) - return (add,) - -Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]} +batch = Dim("batch") +dynamic_shapes = {"x1": {0: dim}, "x2": {0: batch}} ``` -Some things to note: - -- By specifying `{0: dimx}` for the first input, we see that the resulting - shape of the first input is now dynamic, being `[s0]`. And now by specifying - `{0: dimy}` for the second input, we see that the resulting shape of the - second input is also dynamic. However, because we expressed `dimy = dimx + 1`, - instead of `y`'s shape containing a new symbol, we see that it is - now being represented with the same symbol used in `x`, `s0`. We can - see that relationship of `dimy = dimx + 1` is being shown through `s0 + 1`. -- Looking at the range constraints, we see that `s0` has the range [3, 6], - which is specified initially, and we can see that `s0 + 1` has the solved - range of [4, 7]. - -### Serialization - -To save the `ExportedProgram`, users can use the {func}`torch.export.save` and -{func}`torch.export.load` APIs. A convention is to save the `ExportedProgram` -using a `.pt2` file extension. - -An example: +For each `Dim`, we can specify minimum and maximum values. We also allow +specifying relations between `Dim`s in univariate linear expressions: `A * dim + B`. +This allows users to specify more complex constraints like integer divisibility +for dynamic dimensions. These features allow for users to place explicit +restrictions on the dynamic behavior of the `ExportedProgram` produced. ```python -import torch -import io - -class MyModule(torch.nn.Module): - def forward(self, x): - return x + 10 - -exported_program = torch.export.export(MyModule(), torch.randn(5)) - -torch.export.save(exported_program, 'exported_program.pt2') -saved_exported_program = torch.export.load('exported_program.pt2') +dx = Dim("dx", min=4, max=256) +dh = Dim("dh", max=512) +dynamic_shapes = { + "x": (dx, None), + "y": (2 * dx, dh), +} ``` -### Specializations - -A key concept in understanding the behavior of `torch.export` is the -difference between *static* and *dynamic* values. +However, `ConstraintViolationErrors` will be raised if the while tracing, we emit guards +that conflict with the relations or static/dynamic specifications given. For +example, in the above specification, the following is asserted: -A *dynamic* value is one that can change from run to run. These behave like -normal arguments to a Python function—you can pass different values for an -argument and expect your function to do the right thing. Tensor *data* is -treated as dynamic. +* `x.shape[0]` is to have range `[4, 256]`, and related to `y.shape[0]` by `y.shape[0] == 2 * x.shape[0]`. +* `x.shape[1]` is static. +* `y.shape[1]` has range `[0, 512]`, and is unrelated to any other dimension. -A *static* value is a value that is fixed at export time and cannot change -between executions of the exported program. When the value is encountered during -tracing, the exporter will treat it as a constant and hard-code it into the -graph. +If any of these assertions are found to be incorrect while tracing (ex. +`x.shape[0]` is static, or `y.shape[1]` has a smaller range, or +`y.shape[0] != 2 * x.shape[0]`), then a `ConstraintViolationError` will be +raised, and the user will need to change their `dynamic_shapes` specification. -When an operation is performed (e.g. `x + y`) and all inputs are static, then -the output of the operation will be directly hard-coded into the graph, and the -operation won’t show up (i.e. it will get constant-folded). +### Dim Hints -When a value has been hard-coded into the graph, we say that the graph has been -*specialized* to that value. +Instead of explicitly specifying dynamism using `Dim("name")`, we can let +`torch.export` infer the ranges and relationships of the dynamic values using +`Dim.DYNAMIC`. This is also a more convenient way to specify dynamism when you +don't know specifically *how* dynamic your dynamic values are. -The following values are static: - -#### Input Tensor Shapes +```python +dynamic_shapes = { + "x": (Dim.DYNAMIC, None), + "y": (Dim.DYNAMIC, Dim.DYNAMIC), +} +``` -By default, `torch.export` will trace the program specializing on the input -tensors' shapes, unless a dimension is specified as dynamic via the -`dynamic_shapes` argument to `torch.export`. This means that if there exists -shape-dependent control flow, `torch.export` will specialize on the branch -that is being taken with the given sample inputs. For example: +We can also specify min/max values for `Dim.DYNAMIC`, which will serve as hints +to export. But if while tracing export found the range to be different, it will +automatically update the range without raising an error. We also cannot specify +relationships between dynamic values. Instead, this will be inferred by export, +and exposed to users through an inspection of assertions within the graph. In +this method of specifying dynamism, `ConstraintViolationErrors` will **only** be +raised if the specified value is inferred to be **static**. -```python -import torch -from torch.export import export +An even more convenient way to specify dynamism is to use `Dim.AUTO`, which will +behave like `Dim.DYNAMIC`, but will **not** raise an error if the dimension is +inferred to be static. This is useful for when you have no idea what the dynamic +values are, and want to export the program with a "best effort" dynamic approach. -class Mod(torch.nn.Module): - def forward(self, x): - if x.shape[0] > 5: - return x + 1 - else: - return x - 1 +### ShapesCollection -example_inputs = (torch.rand(10, 2),) -exported_program = export(Mod(), example_inputs) -print(exported_program) -``` +When specifying which inputs are dynamic via `dynamic_shapes`, we must specify +the dynamism of every input. For example, given the following inputs: ```python -ExportedProgram: -class GraphModule(torch.nn.Module): - def forward(self, x: "f32[10, 2]"): - # code: return x + 1 - add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1) - return (add,) +args = {"x": tensor_x, "others": [tensor_y, tensor_z]} ``` -The conditional of (`x.shape[0] > 5`) does not appear in the -`ExportedProgram` because the example inputs have the static -shape of (10, 2). Since `torch.export` specializes on the inputs' static -shapes, the else branch (`x - 1`) will never be reached. To preserve the dynamic -branching behavior based on the shape of a tensor in the traced graph, -{func}`torch.export.Dim` will need to be used to specify the dimension -of the input tensor (`x.shape[0]`) to be dynamic, and the source code will -need to be {ref}`rewritten `. +we would need to specify the dynamism of `tensor_x`, `tensor_y`, and `tensor_z` +along with the dynamic shapes: -Note that tensors that are part of the module state (e.g. parameters and -buffers) always have static shapes. - -#### Python Primitives +```python +# With named-Dims +dim = torch.export.Dim(...) +dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]} -`torch.export` also specializes on Python primitives, -such as `int`, `float`, `bool`, and `str`. However they do have dynamic -variants such as `SymInt`, `SymFloat`, and `SymBool`. +torch.export(..., args, dynamic_shapes=dynamic_shapes) +``` -For example: +However, this is particularly complicated as we need to specify the +`dynamic_shapes` specification in the same nested input structure as the input +arguments. Instead, an easier way to specify dynamic shapes is with the helper +utility {class}`torch.export.ShapesCollection`, where instead of specifying the +dynamism of every single input, we can just assign directly which input +dimensions are dynamic. -```python +```{code-cell} import torch -from torch.export import export -class Mod(torch.nn.Module): - def forward(self, x: torch.Tensor, const: int, times: int): - for i in range(times): - x = x + const - return x +class M(torch.nn.Module): + def forward(self, inp): + x = inp["x"] * 1 + y = inp["others"][0] * 2 + z = inp["others"][1] * 3 + return x, y, z -example_inputs = (torch.rand(2, 2), 1, 3) -exported_program = export(Mod(), example_inputs) -print(exported_program) -``` +tensor_x = torch.randn(3, 4, 8) +tensor_y = torch.randn(6) +tensor_z = torch.randn(6) +args = {"x": tensor_x, "others": [tensor_y, tensor_z]} -```python -ExportedProgram: - class GraphModule(torch.nn.Module): - def forward(self, x: "f32[2, 2]", const, times): - # code: x = x + const - add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1) - add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1) - add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1) - return (add_2,) +dim = torch.export.Dim("dim") +sc = torch.export.ShapesCollection() +sc[tensor_x] = (dim, dim + 1, 8) +sc[tensor_y] = {0: dim * 2} + +print(sc.dynamic_shapes(M(), (args,))) +ep = torch.export.export(M(), (args,), dynamic_shapes=sc) +print(ep) ``` -Because integers are specialized, the `torch.ops.aten.add.Tensor` operations -are all computed with the hard-coded constant `1`, rather than `const`. If -a user passes a different value for `const` at runtime, like 2, than the one used -during export time, 1, this will result in an error. -Additionally, the `times` iterator used in the `for` loop is also "inlined" -in the graph through the 3 repeated `torch.ops.aten.add.Tensor` calls, and the -input `times` is never used. +### AdditionalInputs -#### Python Containers +In the case where you don't know how dynamic your inputs are, but you have an +ample set of testing or profiling data that can provide a fair sense of +representative inputs for a model, you can use +{class}`torch.export.AdditionalInputs` in place of `dynamic_shapes`. You can +specify all the possible inputs used to trace the program, and +`AdditionalInputs` will infer which inputs are dynamic based on which input +shapes are changing. -Python containers (`List`, `Dict`, `NamedTuple`, etc.) are considered to -have static structure. +Example: -(limitations-of-torch-export)= +```{code-cell} +import dataclasses +import torch +import torch.utils._pytree as pytree -## Limitations of torch.export +@dataclasses.dataclass +class D: + b: bool + i: int + f: float + t: torch.Tensor -### Graph Breaks +pytree.register_dataclass(D) -As `torch.export` is a one-shot process for capturing a computation graph from -a PyTorch program, it might ultimately run into untraceable parts of programs as -it is nearly impossible to support tracing all PyTorch and Python features. In -the case of `torch.compile`, an unsupported operation will cause a "graph -break" and the unsupported operation will be run with default Python evaluation. -In contrast, `torch.export` will require users to provide additional -information or rewrite parts of their code to make it traceable. As the -tracing is based on TorchDynamo, which evaluates at the Python -bytecode level, there will be significantly fewer rewrites required compared to -previous tracing frameworks. +class M(torch.nn.Module): + def forward(self, d: D): + return d.i + d.f + d.t -When a graph break is encountered, {ref}`ExportDB ` is a great -resource for learning about the kinds of programs that are supported and -unsupported, along with ways to rewrite programs to make them traceable. +input1 = (D(True, 3, 3.0, torch.ones(3)),) +input2 = (D(True, 4, 3.0, torch.ones(4)),) +ai = torch.export.AdditionalInputs() +ai.add(input1) +ai.add(input2) -An option to get past dealing with this graph breaks is by using -{ref}`non-strict export ` +print(ai.dynamic_shapes(M(), input1)) +ep = torch.export.export(M(), input1, dynamic_shapes=ai) +print(ep) +``` -(data-shape-dependent-control-flow)= +## Serialization -### Data/Shape-Dependent Control Flow +To save the `ExportedProgram`, users can use the {func}`torch.export.save` and +{func}`torch.export.load` APIs. The resulting file is a zipfile with a specific +structure. The details of the structure are defined in the +{ref}`PT2 Archive Spec `. -Graph breaks can also be encountered on data-dependent control flow (`if -x.shape[0] > 2`) when shapes are not being specialized, as a tracing compiler cannot -possibly deal with without generating code for a combinatorially exploding -number of paths. In such cases, users will need to rewrite their code using -special control flow operators. Currently, we support {ref}`torch.cond ` -to express if-else like control flow (more coming soon!). +An example: -### Missing Fake/Meta/Abstract Kernels for Operators +```python +import torch -When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is -required for all operators. This is used to reason about the input/output shapes -for this operator. +class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 -Please see {func}`torch.library.register_fake` for more details. +exported_program = torch.export.export(MyModule(), (torch.randn(5),)) -In the unfortunate case where your model uses an ATen operator that is does not -have a FakeTensor kernel implementation yet, please file an issue. +torch.export.save(exported_program, 'exported_program.pt2') +saved_exported_program = torch.export.load('exported_program.pt2') +``` -## Read More +(training-export)= -```{toctree} -:caption: Additional Links for Export Users -:maxdepth: 1 +## Export IR, Decompositions -export.programming_model -export.ir_spec -draft_export -torch.compiler_transformations -torch.compiler_ir -generated/exportdb/index -cond -``` +The graph produced by `torch.export` returns a graph containing only +[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic unit of +computation in PyTorch. As there are over +3000 ATen operators, export provides a way to narrow down the operator set used +in the graph based on certain characteristics, creating different IRs. -```{toctree} -:caption: Deep Dive for PyTorch Developers -:maxdepth: 1 +By default, export produces the most generic IR which contains all ATen +operators, including both functional and non-functional operators. A functional +operator is one that does not contain any mutations or aliasing of the inputs. +You can find a list of all ATen operators +[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml) +and you can inspect if an operator is functional by checking +`op._schema.is_mutable`. -torch.compiler_dynamo_overview -torch.compiler_dynamo_deepdive -torch.compiler_dynamic_shapes -torch.compiler_fake_tensor -``` +This generic IR can be used to train in eager PyTorch Autograd. -## API Reference +```{code-cell} +import torch -```{eval-rst} -.. automodule:: torch.export -``` +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) -```{eval-rst} -.. autofunction:: export -``` + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) -```{eval-rst} -.. autofunction:: save +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +print(ep_for_training.graph_module.print_readable(print_output=False)) ``` -```{eval-rst} -.. autofunction:: load -``` +However, if you want to use the IR for inference, or decrease the amount of +operators being used, you can lower the graph through the +{func}`ExportedProgram.run_decompositions` API. This method decomposes the +ATen operators into the ones specified in the decomposition table, and +functionalizes the graph. -```{eval-rst} -.. autofunction:: draft_export -``` +By specifying an empty set, we're only performing functionalization, and does +not do any additional decompositions. This results in an IR which contains ~2000 +operators (instead of the 3000 operators above), and is ideal for inference cases. -```{eval-rst} -.. autofunction:: register_dataclass -``` +```{code-cell} +import torch -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.Dim -``` +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.ShapesCollection + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) - .. automethod:: dynamic_shapes +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +with torch.no_grad(): + ep_for_inference = ep_for_training.run_decompositions(decomp_table={}) +print(ep_for_inference.graph_module.print_readable(print_output=False)) ``` -```{eval-rst} -.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs +As we can see, the previously in-place operator, +`torch.ops.aten.add_.default` has now been replaced with +`torch.ops.aten.add.default`, a functional operator. - .. automethod:: add - .. automethod:: dynamic_shapes - .. automethod:: verify -``` +We can also further lower this exported program to an operator set which only +contains the +`Core ATen Operator Set `__, +which is a collection of only ~180 operators. This IR is optimal for backends +who do not want to reimplement all ATen operators. -```{eval-rst} -.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes -``` +```{code-cell} +import torch -```{eval-rst} -.. autoclass:: ExportedProgram - - .. attribute:: graph - .. attribute:: graph_signature - .. attribute:: state_dict - .. attribute:: constants - .. attribute:: range_constraints - .. attribute:: module_call_graph - .. attribute:: example_inputs - .. automethod:: module - .. automethod:: run_decompositions -``` +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) -```{eval-rst} -.. autoclass:: ExportGraphSignature -``` + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) -```{eval-rst} -.. autoclass:: ModuleCallSignature +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) +with torch.no_grad(): + core_aten_ir = ep_for_training.run_decompositions(decomp_table=None) +print(core_aten_ir.graph_module.print_readable(print_output=False)) ``` -```{eval-rst} -.. autoclass:: ModuleCallEntry -``` +We now see that `torch.ops.aten.conv2d.default` has been decomposed +into `torch.ops.aten.convolution.default`. This is because `convolution` +is a more "core" operator, as operations like `conv1d` and `conv2d` can be +implemented using the same op. -```{eval-rst} -.. automodule:: torch.export.decomp_utils -``` +We can also specify our own decomposition behaviors: -```{eval-rst} -.. autoclass:: CustomDecompTable +```{code-cell} +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(1, 3, 1, 1) + self.bn = torch.nn.BatchNorm2d(3) - .. automethod:: copy - .. automethod:: items - .. automethod:: keys - .. automethod:: materialize - .. automethod:: pop - .. automethod:: update -``` + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return (x,) -```{eval-rst} -.. autofunction:: torch.export.exported_program.default_decompositions -``` +ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),)) -```{eval-rst} -.. automodule:: torch.export.exported_program -``` +my_decomp_table = torch.export.default_decompositions() -```{eval-rst} -.. automodule:: torch.export.graph_signature +def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): + return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) + +my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function +my_ep = ep_for_training.run_decompositions(my_decomp_table) +print(my_ep.graph_module.print_readable(print_output=False)) ``` -```{eval-rst} -.. autoclass:: ExportGraphSignature +Notice that instead of `torch.ops.aten.conv2d.default` being decomposed +into `torch.ops.aten.convolution.default`, it is now decomposed into +`torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`, +which matches our custom decomposition rule. - .. automethod:: replace_all_uses - .. automethod:: get_replace_hook -``` +(limitations-of-torch-export)= -```{eval-rst} -.. autoclass:: ExportBackwardSignature -``` +## Limitations of torch.export -```{eval-rst} -.. autoclass:: InputKind -``` +As `torch.export` is a one-shot process for capturing a computation graph from +a PyTorch program, it might ultimately run into untraceable parts of programs as +it is nearly impossible to support tracing all PyTorch and Python features. In +the case of `torch.compile`, an unsupported operation will cause a "graph +break" and the unsupported operation will be run with default Python evaluation. +In contrast, `torch.export` will require users to provide additional +information or rewrite parts of their code to make it traceable. -```{eval-rst} -.. autoclass:: InputSpec -``` +{ref}`Draft-export ` is a great resource for listing out +graphs breaks that will be encountered when tracing the program, along with +additional debug information to solve those errors. -```{eval-rst} -.. autoclass:: OutputKind -``` +{ref}`ExportDB ` is also great resource for learning about the +kinds of programs that are supported and unsupported, along with ways to rewrite +programs to make them traceable. -```{eval-rst} -.. autoclass:: OutputSpec -``` +### TorchDynamo unsupported -```{eval-rst} -.. autoclass:: SymIntArgument -``` +When using `torch.export` with `strict=True`, this will use TorchDynamo to +evaluate the program at the Python bytecode level to trace the program into a +graph. Compared to previous tracing frameworks, there will be significantly +fewer rewrites required to make a program traceable, but there will still be +some Python features that are unsupported. An option to get past dealing with +this graph breaks is by using +{ref}`non-strict export ` through changing the `strict` flag +to `strict=False`. -```{eval-rst} -.. autoclass:: SymBoolArgument -``` +(data-shape-dependent-control-flow)= -```{eval-rst} -.. autoclass:: SymFloatArgument -``` +### Data/Shape-Dependent Control Flow -```{eval-rst} -.. autoclass:: CustomObjArgument -``` +Graph breaks can also be encountered on data-dependent control flow (`if +x.shape[0] > 2`) when shapes are not being specialized, as a tracing compiler cannot +possibly deal with without generating code for a combinatorially exploding +number of paths. In such cases, users will need to rewrite their code using +special control flow operators. Currently, we support {ref}`torch.cond ` +to express if-else like control flow (more coming soon!). -```{eval-rst} -.. py:module:: torch.export.dynamic_shapes -``` +You can also refer to this +[tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html#data-dependent-errors) +for more ways of addressing data-dependent errors. -```{eval-rst} -.. py:module:: torch.export.custom_ops -``` +### Missing Fake/Meta Kernels for Operators -```{eval-rst} -.. automodule:: torch.export.unflatten - :members: -``` +When tracing, a FakeTensor kernel (aka meta kernel) is required for all +operators. This is used to reason about the input/output shapes for this +operator. -```{eval-rst} -.. automodule:: torch.export.custom_obj -``` +Please see this [tutorial](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html) +for more details. -```{eval-rst} -.. automodule:: torch.export.experimental -``` +In the unfortunate case where your model uses an ATen operator that is does not +have a FakeTensor kernel implementation yet, please file an issue. -```{eval-rst} -.. automodule:: torch.export.passes -``` +## Read More -```{eval-rst} -.. autofunction:: torch.export.passes.move_to_device_pass -``` +```{toctree} +:caption: Additional Links for Export Users +:maxdepth: 1 -```{eval-rst} -.. automodule:: torch.export.pt2_archive +export/api_reference +export/programming_model +export/ir_spec +export/pt2_archive +export/draft_export +cond +generated/exportdb/index +torch.compiler_aot_inductor +torch.compiler_ir ``` -```{eval-rst} -.. automodule:: torch.export.pt2_archive.constants +```{toctree} +:caption: Deep Dive for PyTorch Developers +:maxdepth: 1 + +torch.compiler_dynamic_shapes +torch.compiler_fake_tensor +torch.compiler_transformations ``` diff --git a/docs/source/export/api_reference.md b/docs/source/export/api_reference.md new file mode 100644 index 000000000000..f729e84e261d --- /dev/null +++ b/docs/source/export/api_reference.md @@ -0,0 +1,69 @@ +(export.api_reference)= + +# torch.export API Reference + +```{eval-rst} +.. automodule:: torch.export + +.. autofunction:: torch.export.export + +.. autoclass:: torch.export.ExportedProgram + :members: + :exclude-members: __init__ + +.. automodule:: torch.export.dynamic_shapes + :members: Dim, ShapesCollection, AdditionalInputs, refine_dynamic_shapes_from_suggested_fixes + +.. autofunction:: torch.export.save + +.. autofunction:: torch.export.load + +.. autofunction:: torch.export.pt2_archive._package.package_pt2 + +.. autofunction:: torch.export.pt2_archive._package.load_pt2 + +.. autofunction:: torch.export.draft_export + +.. automodule:: torch.export.unflatten + :members: + +.. autofunction:: torch.export.register_dataclass + +.. automodule:: torch.export.decomp_utils + :members: + :ignore-module-all: + :undoc-members: + +.. automodule:: torch.export.experimental + :members: + :ignore-module-all: + +.. automodule:: torch.export.passes + :members: + +.. automodule:: torch.export.pt2_archive + :members: + :ignore-module-all: + +.. automodule:: torch.export.pt2_archive.constants + :members: + :ignore-module-all: + +.. automodule:: torch.export.exported_program + :members: + :ignore-module-all: + :exclude-members: ExportedProgram + +.. automodule:: torch.export.custom_ops + :members: + :ignore-module-all: + +.. automodule:: torch.export.custom_obj + :members: + :ignore-module-all: + +.. automodule:: torch.export.graph_signature + :members: + :ignore-module-all: + :undoc-members: +``` diff --git a/docs/source/draft_export.md b/docs/source/export/draft_export.md similarity index 97% rename from docs/source/draft_export.md rename to docs/source/export/draft_export.md index cc7247d3b526..b1ec6ca5d44e 100644 --- a/docs/source/draft_export.md +++ b/docs/source/export/draft_export.md @@ -1,4 +1,4 @@ -(draft-export)= +(export.draft_export)= # Draft Export @@ -126,7 +126,7 @@ Running the `tlparse` command in the terminal will generate a [tlparse](https://github.com/pytorch/tlparse) HTML report. Here is an example of the `tlparse` report: -```{image} _static/img/export/draft_export_report.png +```{image} ../_static/img/export/draft_export_report.png ``` Clicking into the Data Dependent Error, we will see the following page which @@ -136,7 +136,7 @@ contains information to help debug this error. Specifically, it contains: - A list of local variables and their shapes - Information for how this guard was created -```{image} _static/img/export/draft_export_report_dde.png +```{image} ../_static/img/export/draft_export_report_dde.png ``` ## The returned Exported Program @@ -251,12 +251,3 @@ and produce a runnable artifact. This optimized version can then be used for deployment. In parallel, we can utilize the report generated by draft-export to identify and fix `torch.export` errors that were encountered so that the original model can be directly traceable with `torch.export`. - -```{toctree} -:caption: Additional Links -:maxdepth: 1 - -torch.compiler_fake_tensor -torch.compiler_dynamic_shapes -torch.compiler_aot_inductor -``` diff --git a/docs/source/export.ir_spec.md b/docs/source/export/ir_spec.md similarity index 100% rename from docs/source/export.ir_spec.md rename to docs/source/export/ir_spec.md diff --git a/docs/source/export.programming_model.md b/docs/source/export/programming_model.md similarity index 98% rename from docs/source/export.programming_model.md rename to docs/source/export/programming_model.md index 9a21db78464a..d4b81b223fa2 100644 --- a/docs/source/export.programming_model.md +++ b/docs/source/export/programming_model.md @@ -1,4 +1,4 @@ -(export-programming-model)= +(export.programming_model)= # torch.export Programming Model @@ -15,7 +15,9 @@ on different inputs as long as they satisfy the same conditions. The basic output of {func}`torch.export.export` is a single graph of PyTorch operations, with associated metadata. The exact format of this output is -covered in the {ref}`export.ir_spec`. +covered in the {ref}`export IR spec `. + +(non-strict-export)= ### Strict vs. Non-Strict Tracing @@ -120,6 +122,9 @@ Whether a value is static or dynamic depends on its type: - There are dynamic variants for some primitive types (`SymInt`, `SymFloat`, `SymBool`). Typically users do not have to deal with them. + - Users can specify integer inputs as dynamic by specifying + a [dynamic shape](https://pytorch.org/docs/main/export.html#expressing-dynamism) + for it. - For Python *standard containers* (`list`, `tuple`, `dict`, `namedtuple`): @@ -150,7 +155,7 @@ By default, the types of inputs you can use for your program are: - Python primitives (`int`, `float`, `bool`, `str`, `None`) - Python standard containers (`list`, `tuple`, `dict`, `namedtuple`) -### Custom Input Types +### Custom Input Types (PyTree) In addition, you can also define your own (custom) class and use it as an input type, but you will need to register such a class as a PyTree. @@ -164,7 +169,8 @@ class Input: f: torch.Tensor p: torch.Tensor -torch.export.register_dataclass(Input) +import torch.utils._pytree as pytree +pytree.register_dataclass(Input) class M(torch.nn.Module): def forward(self, x: Input): diff --git a/docs/source/export/pt2_archive.md b/docs/source/export/pt2_archive.md new file mode 100644 index 000000000000..cfb589f7bdfe --- /dev/null +++ b/docs/source/export/pt2_archive.md @@ -0,0 +1,122 @@ +(export.pt2_archive)= + +# PT2 Archive Spec + +The following specification defines the archive format which can be produced +through the following methods: + +* {ref}`torch.export ` through calling {func}`torch.export.save` +* {ref}`AOTInductor ` through calling {func}`torch._inductor.aoti_compile_and_package` + +The archive is a zipfile, and can be manipulated using standard zipfile APIs. + +The following is a sample archive. We will walk through the archive folder by folder. + +``` +. +├── archive_format +├── byteorder +├── .data +│ ├── serialization_id +│ └── version +├── data +│ ├── aotinductor +│ │ └── model1 +│ │ ├── aotinductor_pickle_data.json +│ │ ├── cf5ez6ifexr7i2hezzz4s7xfusj4wtisvu2gddeamh37bw6bghjw.cpp +│ │ ├── cf5ez6ifexr7i2hezzz4s7xfusj4wtisvu2gddeamh37bw6bghjw.so +│ │ ├── cg7domx3woam3nnliwud7yvtcencqctxkvvcafuriladwxw4nfiv.cubin +│ │ └── cubaaxppb6xmuqdm4bej55h2pftbce3bjyyvljxbtdfuolmv45ex.cubin +│ ├── weights +│ │ ├── model1_model_param_config.json +│ │ ├── weight_0 +│ │ ├── weight_1 +│ │ ├── weight_2 +│ └── constants +│ │ ├── model1_model_constants_config.json +│ │ ├── tensor_0 +│ │ ├── tensor_1 +│ │ ├── custom_obj_0 +│ │ ├── custom_obj_1 +│ └── sample_inputs +│ ├── model1.pt +│ └── model2.pt +├── extra +│ └── ....json +└── models + ├── model1.json + └── model2.json +``` + +## Contents + +### Archive Headers + +* `archive_format` declares the format used by this archive. Currently, it can only be “pt2”. +* `byteorder`. One of “little” or “big”, used by zip file reader +* `/.data/version` contains the archive version. (Notice that this is neither export serialization’s schema version, nor Aten Opset Version). +* `/.data/serialization_id` is a hash generated for the current archive, used for verification. + + +### AOTInductor Compiled Artifact + +Path: `/data/aotinductor/-/` + +AOTInductor compilation artifacts are saved for each model-backend pair. For +example, compilation artifacts for the `model1` model on A100 and H100 will be +saved in `model1-a100` and `model1-h100` folders separately. + +The folder typically contains +* `.so`: Dynamic library compiled from .cpp. +* `.cpp`: AOTInductor generated cpp wrapper file. +* `*.cubin`: Triton kernels compiled from triton codegen kernels +* (optional) `.json`: External fallback nodes for custom ops to be executed by `ProxyExecutor`, serialized according to `ExternKernelNode` struct. If the model doesn’t use custom ops/ProxyExecutor, this file would be omitted. +* `_metadata.json`: Metadata which was passed in from the `aot_inductor.metadata` inductor config + +### Weights + +Path: `/data/weights/*` + +Model parameters and buffers are saved in the `/data/weights/` folder. Each +tensor is saved as a separated file. The file only contains the raw data blob, +tensor metadata are saved separately in the +`_model_param_config.json`. + +### Constants + +Path: `/data/constants/*` + +TensorConstants, non-persistent buffers and TorchBind objects are saved in the +`/data/constants/` folder. Metadata is saved separately in the +`_model_constants_config.json` + +### Sample Inputs + +Path: `/data/sample_inputs/.pt` + +The `sample_input` used by `torch.export` could be included in the archive for +downstream use. Typically, it’s a flattened list of Tensors, combining both args +and kwargs of the forward() function. + +The .pt file is produced by `torch.save(sample_input)`, and can be loaded by +`torch.load()` in python and `torch::pickle_load()` in c++. + +When the model has multiple copies of sample input, it would be packaged as +`_.pt`. + +### Models Definitions + +Path: `/models/.json` + +Model definition is the serialized json of the ExportedProgram from +`torch.export.save`, and other model-level metadata. + +## Multiple Models + +This archive spec supports multiple model definitions coexisting in the same +file, with `` serving as a unique identifier for the models, and +will be used as reference in other folders of the archive. + +Lower level APIs like {func}`torch.export.pt2_archive._package.package_pt2` and +{func}`torch.export.pt2_archive._package.load_pt2` allow you to have +finer-grained control over the packaging and loading process. diff --git a/docs/source/fx.md b/docs/source/fx.md index 8b60c8064966..831534606abe 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -44,8 +44,7 @@ Your transform will take in a {class}`torch.nn.Module`, acquire a {class}`Graph` from it, do some modifications, and return a new {class}`torch.nn.Module`. You should think of the {class}`torch.nn.Module` that your FX transform returns as identical to a regular {class}`torch.nn.Module` -- you can pass it to another -FX transform, you can pass it to TorchScript, or you can -run it. Ensuring that the inputs and outputs of your FX transform are a +FX transform, or you can run it. Ensuring that the inputs and outputs of your FX transform are a {class}`torch.nn.Module` will allow for composability. ```{note} diff --git a/docs/source/jit.rst b/docs/source/jit.rst index c5ba9063a50c..5295f82f9ac1 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -1,48 +1,21 @@ TorchScript =========== -.. toctree:: - :maxdepth: 1 - :caption: Builtin Functions - :hidden: - - torch.jit.supported_ops - - .. toctree:: :maxdepth: 1 - :caption: Language Reference :hidden: jit_language_reference - - -.. toctree:: - :maxdepth: 1 - jit_language_reference_v2 + jit_python_reference + jit_unsupported + torch.jit.supported_ops - -.. contents:: :local: - :depth: 2 +.. warning:: + TorchScript is deprecated, please use + `torch.export `__ instead. .. automodule:: torch.jit -.. currentmodule:: torch.jit - -TorchScript is a way to create serializable and optimizable models from PyTorch code. -Any TorchScript program can be saved from a Python -process and loaded in a process where there is no Python dependency. - -We provide tools to incrementally transition a model from a pure Python program -to a TorchScript program that can be run independently from Python, such as in a standalone C++ program. -This makes it possible to train models in PyTorch using familiar tools in Python and then export -the model via TorchScript to a production environment where Python programs may be disadvantageous -for performance and multi-threading reasons. - -For a gentle introduction to TorchScript, see the `Introduction to TorchScript `_ tutorial. - -For an end-to-end example of converting a PyTorch model to TorchScript and running it in C++, see the -`Loading a PyTorch Model in C++ `_ tutorial. Creating TorchScript Code -------------------------- @@ -74,817 +47,11 @@ Creating TorchScript Code Attribute annotate -Mixing Tracing and Scripting ----------------------------- - -In many cases either tracing or scripting is an easier approach for converting a model to TorchScript. -Tracing and scripting can be composed to suit the particular requirements -of a part of a model. - -Scripted functions can call traced functions. This is particularly useful when you need -to use control-flow around a simple feed-forward model. For instance the beam search -of a sequence to sequence model will typically be written in script but can call an -encoder module generated using tracing. - - -.. testsetup:: - - # These are hidden from the docs, but these are necessary for `doctest` - # since the `inspect` module doesn't play nicely with the execution - # environment for `doctest` - import torch - - original_script = torch.jit.script - def script_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_script(obj, *args, **kwargs) - - torch.jit.script = script_wrapper - - original_trace = torch.jit.trace - def trace_wrapper(obj, *args, **kwargs): - obj.__module__ = 'FakeMod' - return original_trace(obj, *args, **kwargs) - - torch.jit.trace = trace_wrapper - - -Example (calling a traced function in script): - -.. testcode:: - - import torch - - def foo(x, y): - return 2 * x + y - - traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) - - @torch.jit.script - def bar(x): - return traced_foo(x, x) - -Traced functions can call script functions. This is useful when a small part of -a model requires some control-flow even though most of the model is just a feed-forward -network. Control-flow inside of a script function called by a traced function is -preserved correctly. - -Example (calling a script function in a traced function): - -.. testcode:: - - import torch - - @torch.jit.script - def foo(x, y): - if x.max() > y.max(): - r = x - else: - r = y - return r - - - def bar(x, y, z): - return foo(x, y) + z - - traced_bar = torch.jit.trace(bar, (torch.rand(3), torch.rand(3), torch.rand(3))) - -This composition also works for ``nn.Module``\s as well, where it can be used to generate -a submodule using tracing that can be called from the methods of a script module. - -Example (using a traced module): - -.. testcode:: - :skipif: torchvision is None - - import torch - import torchvision - - class MyScriptModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68]) - .resize_(1, 3, 1, 1)) - self.resnet = torch.jit.trace(torchvision.models.resnet18(), - torch.rand(1, 3, 224, 224)) - - def forward(self, input): - return self.resnet(input - self.means) - - my_script_module = torch.jit.script(MyScriptModule()) - - -TorchScript Language --------------------- - -TorchScript is a statically typed subset of Python, so many Python features apply -directly to TorchScript. See the full :ref:`language-reference` for details. - - -.. _builtin functions: - -Built-in Functions and Modules ------------------------------- - -TorchScript supports the use of most PyTorch functions and many Python built-ins. -See :ref:`builtin-functions` for a full reference of supported functions. - -PyTorch Functions and Modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -TorchScript supports a subset of the tensor and neural network -functions that PyTorch provides. Most methods on Tensor as well as functions in -the ``torch`` namespace, all functions in ``torch.nn.functional`` and -most modules from ``torch.nn`` are supported in TorchScript. - -See :ref:`jit_unsupported` for a list of unsupported PyTorch functions and modules. - - -Python Functions and Modules -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Many of Python's `built-in functions `_ are supported in TorchScript. -The :any:`math` module is also supported (see :ref:`math-module` for details), but no other Python modules -(built-in or third party) are supported. - - -Python Language Reference Comparison -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For a full listing of supported Python features, see :ref:`python-language-reference`. - -Debugging ---------- - -.. _`disable TorchScript`: - -Disable JIT for Debugging -~~~~~~~~~~~~~~~~~~~~~~~~~ -.. envvar:: PYTORCH_JIT - -Setting the environment variable ``PYTORCH_JIT=0`` will disable all script -and tracing annotations. If there is hard-to-debug error in one of your -TorchScript models, you can use this flag to force everything to run using native -Python. Since TorchScript (scripting and tracing) is disabled with this flag, -you can use tools like ``pdb`` to debug the model code. For example:: - - @torch.jit.script - def scripted_fn(x : torch.Tensor): - for i in range(12): - x = x + x - return x - - def fn(x): - x = torch.neg(x) - import pdb; pdb.set_trace() - return scripted_fn(x) - - traced_fn = torch.jit.trace(fn, (torch.rand(4, 5),)) - traced_fn(torch.rand(3, 4)) - -Debugging this script with ``pdb`` works except for when we invoke the -:func:`@torch.jit.script ` function. We can globally disable -JIT, so that we can call the :func:`@torch.jit.script ` -function as a normal Python function and not compile it. If the above script -is called ``disable_jit_example.py``, we can invoke it like so:: - - $ PYTORCH_JIT=0 python disable_jit_example.py - -and we will be able to step into the :func:`@torch.jit.script -` function as a normal Python function. To disable the -TorchScript compiler for a specific function, see -:func:`@torch.jit.ignore `. - -.. _inspecting-code: - -Inspecting Code -~~~~~~~~~~~~~~~ - -TorchScript provides a code pretty-printer for all :class:`ScriptModule` instances. This -pretty-printer gives an interpretation of the script method's code as valid -Python syntax. For example: - -.. testcode:: - - @torch.jit.script - def foo(len): - # type: (int) -> torch.Tensor - rv = torch.zeros(3, 4) - for i in range(len): - if i < 10: - rv = rv - 1.0 - else: - rv = rv + 1.0 - return rv - - print(foo.code) - -.. testoutput:: - :hide: - - ... - -A :class:`ScriptModule` with a single ``forward`` method will have an attribute -``code``, which you can use to inspect the :class:`ScriptModule`'s code. -If the :class:`ScriptModule` has more than one method, you will need to access -``.code`` on the method itself and not the module. We can inspect the -code of a method named ``foo`` on a :class:`ScriptModule` by accessing ``.foo.code``. -The example above produces this output: :: - - def foo(len: int) -> Tensor: - rv = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None) - rv0 = rv - for i in range(len): - if torch.lt(i, 10): - rv1 = torch.sub(rv0, 1., 1) - else: - rv1 = torch.add(rv0, 1., 1) - rv0 = rv1 - return rv0 - -This is TorchScript's compilation of the code for the ``forward`` method. -You can use this to ensure TorchScript (tracing or scripting) has captured -your model code correctly. - - -.. _interpreting-graphs: - -Interpreting Graphs -~~~~~~~~~~~~~~~~~~~ -TorchScript also has a representation at a lower level than the code pretty-\ -printer, in the form of IR graphs. - -TorchScript uses a static single assignment (SSA) intermediate representation -(IR) to represent computation. The instructions in this format consist of -ATen (the C++ backend of PyTorch) operators and other primitive operators, -including control flow operators for loops and conditionals. As an example: - -.. testcode:: - - @torch.jit.script - def foo(len): - # type: (int) -> torch.Tensor - rv = torch.zeros(3, 4) - for i in range(len): - if i < 10: - rv = rv - 1.0 - else: - rv = rv + 1.0 - return rv - - print(foo.graph) - -.. testoutput:: - :hide: - - ... - -``graph`` follows the same rules described in the :ref:`inspecting-code` section -with regard to ``forward`` method lookup. - -The example script above produces the graph:: - - graph(%len.1 : int): - %24 : int = prim::Constant[value=1]() - %17 : bool = prim::Constant[value=1]() # test.py:10:5 - %12 : bool? = prim::Constant() - %10 : Device? = prim::Constant() - %6 : int? = prim::Constant() - %1 : int = prim::Constant[value=3]() # test.py:9:22 - %2 : int = prim::Constant[value=4]() # test.py:9:25 - %20 : int = prim::Constant[value=10]() # test.py:11:16 - %23 : float = prim::Constant[value=1]() # test.py:12:23 - %4 : int[] = prim::ListConstruct(%1, %2) - %rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10 - %rv : Tensor = prim::Loop(%len.1, %17, %rv.1) # test.py:10:5 - block0(%i.1 : int, %rv.14 : Tensor): - %21 : bool = aten::lt(%i.1, %20) # test.py:11:12 - %rv.13 : Tensor = prim::If(%21) # test.py:11:9 - block0(): - %rv.3 : Tensor = aten::sub(%rv.14, %23, %24) # test.py:12:18 - -> (%rv.3) - block1(): - %rv.6 : Tensor = aten::add(%rv.14, %23, %24) # test.py:14:18 - -> (%rv.6) - -> (%17, %rv.13) - return (%rv) - - -Take the instruction ``%rv.1 : Tensor = aten::zeros(%4, %6, %6, %10, %12) # test.py:9:10`` for -example. - -* ``%rv.1 : Tensor`` means we assign the output to a (unique) value named ``rv.1``, that value is of ``Tensor`` type and that we do not know its concrete shape. -* ``aten::zeros`` is the operator (equivalent to ``torch.zeros``) and the input list ``(%4, %6, %6, %10, %12)`` specifies which values in scope should be passed as inputs. The schema for built-in functions like ``aten::zeros`` can be found at `Builtin Functions`_. -* ``# test.py:9:10`` is the location in the original source file that generated this instruction. In this case, it is a file named `test.py`, on line 9, and at character 10. - -Notice that operators can also have associated ``blocks``, namely the -``prim::Loop`` and ``prim::If`` operators. In the graph print-out, these -operators are formatted to reflect their equivalent source code forms -to facilitate easy debugging. - -Graphs can be inspected as shown to confirm that the computation described -by a :class:`ScriptModule` is correct, in both automated and manual fashion, as -described below. - -Tracer -~~~~~~ - - -Tracing Edge Cases -^^^^^^^^^^^^^^^^^^ -There are some edge cases that exist where the trace of a given Python -function/module will not be representative of the underlying code. These -cases can include: - -* Tracing of control flow that is dependent on inputs (e.g. tensor shapes) -* Tracing of in-place operations of tensor views (e.g. indexing on the left-hand side of an assignment) - -Note that these cases may in fact be traceable in the future. - - -Automatic Trace Checking -^^^^^^^^^^^^^^^^^^^^^^^^ -One way to automatically catch many errors in traces is by using ``check_inputs`` -on the ``torch.jit.trace()`` API. ``check_inputs`` takes a list of tuples -of inputs that will be used to re-trace the computation and verify the -results. For example:: - - def loop_in_traced_fn(x): - result = x[0] - for i in range(x.size(0)): - result = result * x[i] - return result - - inputs = (torch.rand(3, 4, 5),) - check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] - - traced = torch.jit.trace(loop_in_traced_fn, inputs, check_inputs=check_inputs) - -Gives us the following diagnostic information:: - - ERROR: Graphs differed across invocations! - Graph diff: - - graph(%x : Tensor) { - %1 : int = prim::Constant[value=0]() - %2 : int = prim::Constant[value=0]() - %result.1 : Tensor = aten::select(%x, %1, %2) - %4 : int = prim::Constant[value=0]() - %5 : int = prim::Constant[value=0]() - %6 : Tensor = aten::select(%x, %4, %5) - %result.2 : Tensor = aten::mul(%result.1, %6) - %8 : int = prim::Constant[value=0]() - %9 : int = prim::Constant[value=1]() - %10 : Tensor = aten::select(%x, %8, %9) - - %result : Tensor = aten::mul(%result.2, %10) - + %result.3 : Tensor = aten::mul(%result.2, %10) - ? ++ - %12 : int = prim::Constant[value=0]() - %13 : int = prim::Constant[value=2]() - %14 : Tensor = aten::select(%x, %12, %13) - + %result : Tensor = aten::mul(%result.3, %14) - + %16 : int = prim::Constant[value=0]() - + %17 : int = prim::Constant[value=3]() - + %18 : Tensor = aten::select(%x, %16, %17) - - %15 : Tensor = aten::mul(%result, %14) - ? ^ ^ - + %19 : Tensor = aten::mul(%result, %18) - ? ^ ^ - - return (%15); - ? ^ - + return (%19); - ? ^ - } - - -This message indicates to us that the computation differed between when -we first traced it and when we traced it with the ``check_inputs``. Indeed, -the loop within the body of ``loop_in_traced_fn`` depends on the shape -of the input ``x``, and thus when we try another ``x`` with a different -shape, the trace differs. - -In this case, data-dependent control flow like this can be captured using -:func:`torch.jit.script` instead: - -.. testcode:: - - def fn(x): - result = x[0] - for i in range(x.size(0)): - result = result * x[i] - return result - - inputs = (torch.rand(3, 4, 5),) - check_inputs = [(torch.rand(4, 5, 6),), (torch.rand(2, 3, 4),)] - - scripted_fn = torch.jit.script(fn) - print(scripted_fn.graph) - #print(str(scripted_fn.graph).strip()) - - for input_tuple in [inputs] + check_inputs: - torch.testing.assert_close(fn(*input_tuple), scripted_fn(*input_tuple)) - -.. testoutput:: - :hide: - - ... - - -Which produces:: - - graph(%x : Tensor) { - %5 : bool = prim::Constant[value=1]() - %1 : int = prim::Constant[value=0]() - %result.1 : Tensor = aten::select(%x, %1, %1) - %4 : int = aten::size(%x, %1) - %result : Tensor = prim::Loop(%4, %5, %result.1) - block0(%i : int, %7 : Tensor) { - %10 : Tensor = aten::select(%x, %1, %i) - %result.2 : Tensor = aten::mul(%7, %10) - -> (%5, %result.2) - } - return (%result); - } - -Tracer Warnings -^^^^^^^^^^^^^^^ -The tracer produces warnings for several problematic patterns in traced -computation. As an example, take a trace of a function that contains an -in-place assignment on a slice (a view) of a Tensor: - -.. testcode:: - - def fill_row_zero(x): - x[0] = torch.rand(*x.shape[1:2]) - return x - - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - print(traced.graph) - -.. testoutput:: - :hide: - - ... - -Produces several warnings and a graph which simply returns the input:: - - fill_row_zero.py:4: TracerWarning: There are 2 live references to the data region being modified when tracing in-place operator copy_ (possibly due to an assignment). This might cause the trace to be incorrect, because all other views that also reference this data will not reflect this change in the trace! On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. are outputs of torch.split), this might still be safe. - x[0] = torch.rand(*x.shape[1:2]) - fill_row_zero.py:6: TracerWarning: Output nr 1. of the traced function does not match the corresponding output of the Python function. Detailed error: - Not within tolerance rtol=1e-05 atol=1e-05 at input[0, 1] (0.09115803241729736 vs. 0.6782537698745728) and 3 other locations (33.00%) - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - graph(%0 : Float(3, 4)) { - return (%0); - } - -We can fix this by modifying the code to not use the in-place update, but -rather build up the result tensor out-of-place with ``torch.cat``: - -.. testcode:: - - def fill_row_zero(x): - x = torch.cat((torch.rand(1, *x.shape[1:2]), x[1:2]), dim=0) - return x - - traced = torch.jit.trace(fill_row_zero, (torch.rand(3, 4),)) - print(traced.graph) - -.. testoutput:: - :hide: - - ... - -Frequently Asked Questions --------------------------- - -Q: I would like to train a model on GPU and do inference on CPU. What are the -best practices? - - First convert your model from GPU to CPU and then save it, like so: :: - - cpu_model = gpu_model.cpu() - sample_input_cpu = sample_input_gpu.cpu() - traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu) - torch.jit.save(traced_cpu, "cpu.pt") - - traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu) - torch.jit.save(traced_gpu, "gpu.pt") - - # ... later, when using the model: - - if use_gpu: - model = torch.jit.load("gpu.pt") - else: - model = torch.jit.load("cpu.pt") - - model(input) - - This is recommended because the tracer may witness tensor creation on a - specific device, so casting an already-loaded model may have unexpected - effects. Casting the model *before* saving it ensures that the tracer has - the correct device information. - - -Q: How do I store attributes on a :class:`ScriptModule`? - - Say we have a model like: - - .. testcode:: - - import torch - - class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.x = 2 - - def forward(self): - return self.x - - m = torch.jit.script(Model()) - - - - If ``Model`` is instantiated it will result in a compilation error - since the compiler doesn't know about ``x``. There are 4 ways to inform the - compiler of attributes on :class:`ScriptModule`: - - 1. ``nn.Parameter`` - Values wrapped in ``nn.Parameter`` will work as they - do on ``nn.Module``\s - - 2. ``register_buffer`` - Values wrapped in ``register_buffer`` will work as - they do on ``nn.Module``\s. This is equivalent to an attribute (see 4) of type - ``Tensor``. - - 3. Constants - Annotating a class member as ``Final`` (or adding it to a list called - ``__constants__`` at the class definition level) will mark the contained names - as constants. Constants are saved directly in the code of the model. See - `builtin-constants` for details. - - 4. Attributes - Values that are a `supported type` can be added as mutable - attributes. Most types can be inferred but some may need to be specified, see - `module attributes` for details. - -Q: I would like to trace module's method but I keep getting this error: - -``RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient`` - - This error usually means that the method you are tracing uses a module's parameters and - you are passing the module's method instead of the module instance (e.g. ``my_module_instance.forward`` vs ``my_module_instance``). - - - Invoking ``trace`` with a module's method captures module parameters (which may require gradients) as **constants**. - - On the other hand, invoking ``trace`` with module's instance (e.g. ``my_module``) creates a new module and correctly copies parameters into the new module, so they can accumulate gradients if required. - - To trace a specific method on a module, see :func:`torch.jit.trace_module ` - -Known Issues ---------------- - -If you're using ``Sequential`` with TorchScript, the inputs of some -of the ``Sequential`` submodules may be falsely inferred to be -``Tensor``, even if they're annotated otherwise. The canonical -solution is to subclass ``nn.Sequential`` and redeclare ``forward`` -with the input typed correctly. - -Appendix --------- - -Migrating to PyTorch 1.2 Recursive Scripting API -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This section details the changes to TorchScript in PyTorch 1.2. If you are new to TorchScript you can -skip this section. There are two main changes to the TorchScript API with PyTorch 1.2. - -1. :func:`torch.jit.script ` will now attempt to recursively compile functions, -methods, and classes that it encounters. Once you call ``torch.jit.script``, -compilation is "opt-out", rather than "opt-in". - -2. ``torch.jit.script(nn_module_instance)`` is now the preferred way to create -:class:`ScriptModule`\s, instead of inheriting from ``torch.jit.ScriptModule``. -These changes combine to provide a simpler, easier-to-use API for converting -your ``nn.Module``\s into :class:`ScriptModule`\s, ready to be optimized and executed in a -non-Python environment. - -The new usage looks like this: - -.. testcode:: - - import torch - import torch.nn as nn - import torch.nn.functional as F - - class Model(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 20, 5) - self.conv2 = nn.Conv2d(20, 20, 5) - - def forward(self, x): - x = F.relu(self.conv1(x)) - return F.relu(self.conv2(x)) - - my_model = Model() - my_scripted_model = torch.jit.script(my_model) - - -* The module's ``forward`` is compiled by default. Methods called from ``forward`` are lazily compiled in the order they are used in ``forward``. -* To compile a method other than ``forward`` that is not called from ``forward``, add ``@torch.jit.export``. -* To stop the compiler from compiling a method, add :func:`@torch.jit.ignore ` or :func:`@torch.jit.unused `. ``@ignore`` leaves the -* method as a call to python, and ``@unused`` replaces it with an exception. ``@ignored`` cannot be exported; ``@unused`` can. -* Most attribute types can be inferred, so ``torch.jit.Attribute`` is not necessary. For empty container types, annotate their types using `PEP 526-style `_ class annotations. -* Constants can be marked with a ``Final`` class annotation instead of adding the name of the member to ``__constants__``. -* Python 3 type hints can be used in place of ``torch.jit.annotate`` - -As a result of these changes, the following items are considered deprecated and should not appear in new code: - * The ``@torch.jit.script_method`` decorator - * Classes that inherit from ``torch.jit.ScriptModule`` - * The ``torch.jit.Attribute`` wrapper class - * The ``__constants__`` array - * The ``torch.jit.annotate`` function - -Modules -^^^^^^^ -.. warning:: - - The :func:`@torch.jit.ignore ` annotation's behavior changes in - PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function - or method callable from code that is exported. To get this functionality back, - use ``@torch.jit.unused()``. ``@torch.jit.ignore`` is now equivalent - to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore ` - and :func:`@torch.jit.unused` for details. - -When passed to the :func:`torch.jit.script ` function, a ``torch.nn.Module``\'s data is -copied to a :class:`ScriptModule` and the TorchScript compiler compiles the module. -The module's ``forward`` is compiled by default. Methods called from ``forward`` are -lazily compiled in the order they are used in ``forward``, as well as any -``@torch.jit.export`` methods. - -.. autofunction:: export - -Functions -^^^^^^^^^ -Functions don't change much, they can be decorated with :func:`@torch.jit.ignore ` or :func:`torch.jit.unused ` if needed. - -.. testcode:: - - # Same behavior as pre-PyTorch 1.2 - @torch.jit.script - def some_fn(): - return 2 - - # Marks a function as ignored, if nothing - # ever calls it then this has no effect - @torch.jit.ignore - def some_fn2(): - return 2 - - # As with ignore, if nothing calls it then it has no effect. - # If it is called in script it is replaced with an exception. - @torch.jit.unused - def some_fn3(): - import pdb; pdb.set_trace() - return 4 - - # Doesn't do anything, this function is already - # the main entry point - @torch.jit.export - def some_fn4(): - return 2 - -TorchScript Classes -^^^^^^^^^^^^^^^^^^^ - -.. warning:: - - TorchScript class support is experimental. Currently it is best suited - for simple record-like types (think a ``NamedTuple`` with methods - attached). - -Everything in a user defined `TorchScript Class `_ is -exported by default, functions can be decorated with :func:`@torch.jit.ignore -` if needed. - -Attributes -^^^^^^^^^^ -The TorchScript compiler needs to know the types of `module attributes`. Most types -can be inferred from the value of the member. Empty lists and dicts cannot have their -types inferred and must have their types annotated with `PEP 526-style `_ class annotations. -If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute -to the resulting :class:`ScriptModule` - - -Old API: - -.. testcode:: - - from typing import Dict - import torch - - class MyModule(torch.jit.ScriptModule): - def __init__(self): - super().__init__() - self.my_dict = torch.jit.Attribute({}, Dict[str, int]) - self.my_int = torch.jit.Attribute(20, int) - - m = MyModule() - -New API: - -.. testcode:: - - from typing import Dict - - class MyModule(torch.nn.Module): - my_dict: Dict[str, int] - - def __init__(self): - super().__init__() - # This type cannot be inferred and must be specified - self.my_dict = {} - - # The attribute type here is inferred to be `int` - self.my_int = 20 - - def forward(self): - pass - - m = torch.jit.script(MyModule()) - - -Constants -^^^^^^^^^ -The ``Final`` type constructor can be used to mark members as `constant`. If members are not marked constant, they will be copied to the resulting :class:`ScriptModule` as an attribute. Using ``Final`` opens opportunities for optimization if the value is known to be fixed and gives additional type safety. - -Old API: - -.. testcode:: - - class MyModule(torch.jit.ScriptModule): - __constants__ = ['my_constant'] - - def __init__(self): - super().__init__() - self.my_constant = 2 - - def forward(self): - pass - m = MyModule() - -New API: - -:: - - from typing import Final - - class MyModule(torch.nn.Module): - - my_constant: Final[int] - - def __init__(self): - super().__init__() - self.my_constant = 2 - - def forward(self): - pass - - m = torch.jit.script(MyModule()) - -.. _Python 3 type hints: - -Variables -^^^^^^^^^ -Containers are assumed to have type ``Tensor`` and be non-optional (see -`Default Types` for more information). Previously, ``torch.jit.annotate`` was used to -tell the TorchScript compiler what the type should be. Python 3 style type hints are -now supported. - -.. testcode:: - - import torch - from typing import Dict, Optional - - @torch.jit.script - def make_dict(flag: bool): - x: Dict[str, int] = {} - x['hi'] = 2 - b: Optional[int] = None - if flag: - b = 2 - return x, b - -Fusion Backends -~~~~~~~~~~~~~~~ -There are a couple of fusion backends available to optimize TorchScript execution. The default fuser on CPUs is NNC, which can perform fusions for both CPUs and GPUs. The default fuser on GPUs is NVFuser, which supports a wider range of operators and has demonstrated generated kernels with improved throughput. See the `NVFuser documentation `_ for more details on usage and debugging. - - -References -~~~~~~~~~~ -.. toctree:: - :maxdepth: 1 - - jit_python_reference - jit_unsupported .. This package is missing doc. Adding it here for coverage .. This does not add anything to the rendered page. +.. py:module:: torch.jit.supported_ops +.. py:module:: torch.jit.unsupported_tensor_ops .. py:module:: torch.jit.mobile .. py:module:: torch.jit.annotations .. py:module:: torch.jit.frontend diff --git a/docs/source/jit_builtin_functions.rst b/docs/source/jit_builtin_functions.rst index a6cdb8c47870..6fd514f6e6fc 100644 --- a/docs/source/jit_builtin_functions.rst +++ b/docs/source/jit_builtin_functions.rst @@ -3,8 +3,6 @@ TorchScript Builtins ==================== -This is a full reference of functions and Tensor methods accessible in TorchScript - -.. contents:: :local: - -.. automodule:: torch.jit.supported_ops +.. warning:: + TorchScript is deprecated, please use + `torch.export `__ instead. diff --git a/docs/source/jit_language_reference.md b/docs/source/jit_language_reference.md index 973730948208..f2b31768e2d5 100644 --- a/docs/source/jit_language_reference.md +++ b/docs/source/jit_language_reference.md @@ -30,923 +30,7 @@ # TorchScript Language Reference -TorchScript is a statically typed subset of Python that can either be written directly (using -the {func}`@torch.jit.script ` decorator) or generated automatically from Python code via -tracing. When using tracing, code is automatically converted into this subset of -Python by recording only the actual operators on tensors and simply executing and -discarding the other surrounding Python code. - -When writing TorchScript directly using `@torch.jit.script` decorator, the programmer must -only use the subset of Python supported in TorchScript. This section documents -what is supported in TorchScript as if it were a language reference for a stand -alone language. Any features of Python not mentioned in this reference are not -part of TorchScript. See `Builtin Functions` for a complete reference of available -PyTorch tensor methods, modules, and functions. - -As a subset of Python, any valid TorchScript function is also a valid Python -function. This makes it possible to `disable TorchScript` and debug the -function using standard Python tools like `pdb`. The reverse is not true: there -are many valid Python programs that are not valid TorchScript programs. -Instead, TorchScript focuses specifically on the features of Python that are -needed to represent neural network models in PyTorch. - -(types)= - -(supported-type)= - -## Types - -The largest difference between TorchScript and the full Python language is that -TorchScript only supports a small set of types that are needed to express neural -net models. In particular, TorchScript supports: - -```{eval-rst} -.. csv-table:: - :header: "Type", "Description" - - "``Tensor``", "A PyTorch tensor of any dtype, dimension, or backend" - "``Tuple[T0, T1, ..., TN]``", "A tuple containing subtypes ``T0``, ``T1``, etc. (e.g. ``Tuple[Tensor, Tensor]``)" - "``bool``", "A boolean value" - "``int``", "A scalar integer" - "``float``", "A scalar floating point number" - "``str``", "A string" - "``List[T]``", "A list of which all members are type ``T``" - "``Optional[T]``", "A value which is either None or type ``T``" - "``Dict[K, V]``", "A dict with key type ``K`` and value type ``V``. Only ``str``, ``int``, and ``float`` are allowed as key types." - "``T``", "A {ref}`TorchScript Class`" - "``E``", "A {ref}`TorchScript Enum`" - "``NamedTuple[T0, T1, ...]``", "A :func:`collections.namedtuple ` tuple type" - "``Union[T0, T1, ...]``", "One of the subtypes ``T0``, ``T1``, etc." -``` - -Unlike Python, each variable in TorchScript function must have a single static type. -This makes it easier to optimize TorchScript functions. - -Example (a type mismatch) - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def an_error(x): - if x: - r = torch.rand(1) - else: - r = 4 - return r - -``` - -```{eval-rst} -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: - @torch.jit.script - def an_error(x): - if x: - ~~~~~ - r = torch.rand(1) - ~~~~~~~~~~~~~~~~~ - else: - ~~~~~ - r = 4 - ~~~~~ <--- HERE - return r - and was used here: - else: - r = 4 - return r - ~ <--- HERE... -``` - -### Unsupported Typing Constructs - -TorchScript does not support all features and types of the {mod}`typing` module. Some of these -are more fundamental things that are unlikely to be added in the future while others -may be added if there is enough user demand to make it a priority. - -These types and features from the {mod}`typing` module are unavailable in TorchScript. - -```{eval-rst} -.. csv-table:: - :header: "Item", "Description" - - ":any:`typing.Any`", ":any:`typing.Any` is currently in development but not yet released" - ":any:`typing.NoReturn`", "Not implemented" - ":any:`typing.Sequence`", "Not implemented" - ":any:`typing.Callable`", "Not implemented" - ":any:`typing.Literal`", "Not implemented" - ":any:`typing.ClassVar`", "Not implemented" - ":any:`typing.Final`", "This is supported for :any:`module attributes ` class attribute annotations but not for functions" - ":any:`typing.AnyStr`", "TorchScript does not support :any:`bytes` so this type is not used" - ":any:`typing.overload`", ":any:`typing.overload` is currently in development but not yet released" - "Type aliases", "Not implemented" - "Nominal vs structural subtyping", "Nominal typing is in development, but structural typing is not" - "NewType", "Unlikely to be implemented" - "Generics", "Unlikely to be implemented" -``` - -Any other functionality from the {any}`typing` module not explicitly listed in this documentation is unsupported. - -### Default Types - -By default, all parameters to a TorchScript function are assumed to be Tensor. -To specify that an argument to a TorchScript function is another type, it is possible to use -MyPy-style type annotations using the types listed above. - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def foo(x, tup): - # type: (int, Tuple[Tensor, Tensor]) -> Tensor - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... -``` - -:::{note} -It is also possible to annotate types with Python 3 type hints from the -`typing` module. - -```{eval-rst} -.. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def foo(x: int, tup: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - t0, t1 = tup - return t0 + t1 + x - - print(foo(3, (torch.rand(3), torch.rand(3)))) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... -``` -::: - -An empty list is assumed to be `List[Tensor]` and empty dicts -`Dict[str, Tensor]`. To instantiate an empty list or dict of other types, -use `Python 3 type hints`. - -Example (type annotations for Python 3): - -```{eval-rst} -.. testcode:: - - import torch - import torch.nn as nn - from typing import Dict, List, Tuple - - class EmptyDataStructures(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]: - # This annotates the list to be a `List[Tuple[int, float]]` - my_list: List[Tuple[int, float]] = [] - for i in range(10): - my_list.append((i, x.item())) - - my_dict: Dict[str, int] = {} - return my_list, my_dict - - x = torch.jit.script(EmptyDataStructures()) - - - -``` - -### Optional Type Refinement - -TorchScript will refine the type of a variable of type `Optional[T]` when -a comparison to `None` is made inside the conditional of an if-statement or checked in an `assert`. -The compiler can reason about multiple `None` checks that are combined with -`and`, `or`, and `not`. Refinement will also occur for else blocks of if-statements -that are not explicitly written. - -The `None` check must be within the if-statement's condition; assigning -a `None` check to a variable and using it in the if-statement's condition will -not refine the types of variables in the check. -Only local variables will be refined, an attribute like `self.x` will not and must assigned to -a local variable to be refined. - -Example (refining types on parameters and locals): - -```{eval-rst} -.. testcode:: - - import torch - import torch.nn as nn - from typing import Optional - - class M(nn.Module): - z: Optional[int] - - def __init__(self, z): - super().__init__() - # If `z` is None, its type cannot be inferred, so it must - # be specified (above) - self.z = z - - def forward(self, x, y, z): - # type: (Optional[int], Optional[int], Optional[int]) -> int - if x is None: - x = 1 - x = x + 1 - - # Refinement for an attribute by assigning it to a local - z = self.z - if y is not None and z is not None: - x = y + z - - # Refinement via an `assert` - assert z is not None - x += z - return x - - module = torch.jit.script(M(2)) - module = torch.jit.script(M(None)) - -``` - -(TorchScript Class)= - -(TorchScript Classes)= - -(torchscript-classes)= - -### TorchScript Classes - :::{warning} -TorchScript class support is experimental. Currently it is best suited -for simple record-like types (think a `NamedTuple` with methods -attached). -::: - -Python classes can be used in TorchScript if they are annotated with {func}`@torch.jit.script `, -similar to how you would declare a TorchScript function: - -```{eval-rst} -.. testcode:: - :skipif: True # TODO: fix the source file resolving so this can be tested - - @torch.jit.script - class Foo: - def __init__(self, x, y): - self.x = x - - def aug_add_x(self, inc): - self.x += inc - -``` - -This subset is restricted: - -- All functions must be valid TorchScript functions (including `__init__()`). - -- Classes must be new-style classes, as we use `__new__()` to construct them with pybind11. - -- TorchScript classes are statically typed. Members can only be declared by assigning to - self in the `__init__()` method. - - > For example, assigning to `self` outside of the `__init__()` method: - > - > ``` - > @torch.jit.script - > class Foo: - > def assign_x(self): - > self.x = torch.rand(2, 3) - > ``` - > - > Will result in: - > - > ``` - > RuntimeError: - > Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: - > def assign_x(self): - > self.x = torch.rand(2, 3) - > ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE - > ``` - -- No expressions except method definitions are allowed in the body of the class. - -- No support for inheritance or any other polymorphism strategy, except for inheriting - from `object` to specify a new-style class. - -After a class is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type: - -``` -# Declare a TorchScript class -@torch.jit.script -class Pair: - def __init__(self, first, second): - self.first = first - self.second = second - -@torch.jit.script -def sum_pair(p): - # type: (Pair) -> Tensor - return p.first + p.second - -p = Pair(torch.rand(2, 3), torch.rand(2, 3)) -print(sum_pair(p)) -``` - -(TorchScript Enum)= - -(TorchScript Enums)= - -(torchscript-enums)= - -### TorchScript Enums - -Python enums can be used in TorchScript without any extra annotation or code: - -``` -from enum import Enum - - -class Color(Enum): - RED = 1 - GREEN = 2 - -@torch.jit.script -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - - return x == y -``` - -After an enum is defined, it can be used in both TorchScript and Python interchangeably -like any other TorchScript type. The type of the values of an enum must be `int`, -`float`, or `str`. All values must be of the same type; heterogeneous types for enum -values are not supported. - -### Named Tuples - -Types produced by {func}`collections.namedtuple ` can be used in TorchScript. - -```{eval-rst} -.. testcode:: - - import torch - import collections - - Point = collections.namedtuple('Point', ['x', 'y']) - - @torch.jit.script - def total(point): - # type: (Point) -> Tensor - return point.x + point.y - - p = Point(x=torch.rand(3), y=torch.rand(3)) - print(total(p)) -``` - -```{eval-rst} -.. testoutput:: - :hide: - - ... - -``` - -(jit_iterables)= - -### Iterables - -Some functions (for example, {any}`zip` and {any}`enumerate`) can only operate on iterable types. -Iterable types in TorchScript include `Tensor`s, lists, tuples, dictionaries, strings, -{any}`torch.nn.ModuleList` and {any}`torch.nn.ModuleDict`. - -## Expressions - -The following Python Expressions are supported. - -### Literals - -``` -True -False -None -'string literals' -"string literals" -3 # interpreted as int -3.4 # interpreted as a float -``` - -#### List Construction - -An empty list is assumed have type `List[Tensor]`. -The types of other list literals are derived from the type of the members. -See [Default Types] for more details. - -``` -[3, 4] -[] -[torch.rand(3), torch.rand(4)] -``` - -#### Tuple Construction - -``` -(3, 4) -(3,) -``` - -#### Dict Construction - -An empty dict is assumed have type `Dict[str, Tensor]`. -The types of other dict literals are derived from the type of the members. -See [Default Types] for more details. - -``` -{'hello': 3} -{} -{'a': torch.rand(3), 'b': torch.rand(4)} -``` - -### Variables - -See [Variable Resolution] for how variables are resolved. - -``` -my_variable_name -``` - -### Arithmetic Operators - -``` -a + b -a - b -a * b -a / b -a ^ b -a @ b -``` - -### Comparison Operators - -``` -a == b -a != b -a < b -a > b -a <= b -a >= b -``` - -### Logical Operators - -``` -a and b -a or b -not b -``` - -### Subscripts and Slicing - -``` -t[0] -t[-1] -t[0:2] -t[1:] -t[:1] -t[:] -t[0, 1] -t[0, 1:2] -t[0, :1] -t[-1, 1:, 0] -t[1:, -1, 0] -t[i:j, i] -``` - -### Function Calls - -Calls to `builtin functions` - -``` -torch.rand(3, dtype=torch.int) -``` - -Calls to other script functions: - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def foo(x): - return x + 1 - - @torch.jit.script - def bar(x): - return foo(x) -``` - -### Method Calls - -Calls to methods of builtin types like tensor: `x.mm(y)` - -On modules, methods must be compiled before they can be called. The TorchScript -compiler recursively compiles methods it sees when compiling other methods. By default, -compilation starts on the `forward` method. Any methods called by `forward` will -be compiled, and any methods called by those methods, and so on. To start compilation at -a method other than `forward`, use the {func}`@torch.jit.export ` decorator -(`forward` implicitly is marked `@torch.jit.export`). - -Calling a submodule directly (e.g. `self.resnet(input)`) is equivalent to -calling its `forward` method (e.g. `self.resnet.forward(input)`). - -```{eval-rst} -.. testcode:: - :skipif: torchvision is None - - import torch - import torch.nn as nn - import torchvision - - class MyModule(nn.Module): - def __init__(self): - super().__init__() - means = torch.tensor([103.939, 116.779, 123.68]) - self.means = torch.nn.Parameter(means.resize_(1, 3, 1, 1)) - resnet = torchvision.models.resnet18() - self.resnet = torch.jit.trace(resnet, torch.rand(1, 3, 224, 224)) - - def helper(self, input): - return self.resnet(input - self.means) - - def forward(self, input): - return self.helper(input) - - # Since nothing in the model calls `top_level_method`, the compiler - # must be explicitly told to compile this method - @torch.jit.export - def top_level_method(self, input): - return self.other_helper(input) - - def other_helper(self, input): - return input + 10 - - # `my_script_module` will have the compiled methods `forward`, `helper`, - # `top_level_method`, and `other_helper` - my_script_module = torch.jit.script(MyModule()) - -``` - -### Ternary Expressions - -``` -x if x > y else y -``` - -### Casts - -``` -float(ten) -int(3.5) -bool(ten) -str(2)`` -``` - -### Accessing Module Parameters - -``` -self.my_parameter -self.my_submodule.my_parameter -``` - -## Statements - -TorchScript supports the following types of statements: - -### Simple Assignments - -``` -a = b -a += b # short-hand for a = a + b, does not operate in-place on a -a -= b -``` - -### Pattern Matching Assignments - -``` -a, b = tuple_or_list -a, b, *c = a_tuple -``` - -Multiple Assignments - -``` -a = b, c = tup -``` - -### Print Statements - -``` -print("the result of an add:", a + b) -``` - -### If Statements - -``` -if a < 4: - r = -a -elif a < 3: - r = a + a -else: - r = 3 * a -``` - -In addition to bools, floats, ints, and Tensors can be used in a conditional -and will be implicitly casted to a boolean. - -### While Loops - -``` -a = 0 -while a < 4: - print(a) - a += 1 -``` - -### For loops with range - -``` -x = 0 -for i in range(10): - x *= i -``` - -### For loops over tuples - -These unroll the loop, generating a body for -each member of the tuple. The body must type-check correctly for each member. - -``` -tup = (3, torch.rand(4)) -for x in tup: - print(x) -``` - -### For loops over constant nn.ModuleList - -To use a `nn.ModuleList` inside a compiled method, it must be marked -constant by adding the name of the attribute to the `__constants__` -list for the type. For loops over a `nn.ModuleList` will unroll the body of the -loop at compile time, with each member of the constant module list. - -```{eval-rst} -.. testcode:: - - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - - class MyModule(torch.nn.Module): - __constants__ = ['mods'] - - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - - - m = torch.jit.script(MyModule()) - - -``` - -### Break and Continue - -``` -for i in range(5): - if i == 1: - continue - if i == 3: - break - print(i) -``` - -### Return - -``` -return a, b -``` - -## Variable Resolution - -TorchScript supports a subset of Python's variable resolution (i.e. scoping) -rules. Local variables behave the same as in Python, except for the restriction -that a variable must have the same type along all paths through a function. -If a variable has a different type on different branches of an if statement, it -is an error to use it after the end of the if statement. - -Similarly, a variable is not allowed to be used if it is only *defined* along some -paths through the function. - -Example: - -```{eval-rst} -.. testcode:: - - @torch.jit.script - def foo(x): - if x < 0: - y = 4 - print(y) -``` - -```{eval-rst} -.. testoutput:: - - Traceback (most recent call last): - ... - RuntimeError: ... - - y is not defined in the false branch... - @torch.jit.script... - def foo(x): - if x < 0: - ~~~~~~~~~ - y = 4 - ~~~~~ <--- HERE - print(y) - and was used here: - if x < 0: - y = 4 - print(y) - ~ <--- HERE... -``` - -Non-local variables are resolved to Python values at compile time when the -function is defined. These values are then converted into TorchScript values using -the rules described in [Use of Python Values]. - -## Use of Python Values - -To make writing TorchScript more convenient, we allow script code to refer -to Python values in the surrounding scope. For instance, any time there is a -reference to `torch`, the TorchScript compiler is actually resolving it to the -`torch` Python module when the function is declared. These Python values are -not a first class part of TorchScript. Instead they are de-sugared at compile-time -into the primitive types that TorchScript supports. This depends -on the dynamic type of the Python valued referenced when compilation occurs. -This section describes the rules that are used when accessing Python values in TorchScript. - -### Functions - -TorchScript can call Python functions. This functionality is very useful when -incrementally converting a model to TorchScript. The model can be moved function-by-function -to TorchScript, leaving calls to Python functions in place. This way you can incrementally -check the correctness of the model as you go. - -```{eval-rst} -.. autofunction:: torch.jit.is_scripting -``` - -```{eval-rst} -.. autofunction:: torch.jit.is_tracing - -``` - -### Attribute Lookup On Python Modules - -TorchScript can lookup attributes on modules. `Builtin functions` like `torch.add` -are accessed this way. This allows TorchScript to call functions defined in -other modules. - -(constant)= - -### Python-defined Constants - -TorchScript also provides a way to use constants that are defined in Python. -These can be used to hard-code hyper-parameters into the function, or to -define universal constants. There are two ways of specifying that a Python -value should be treated as a constant. - -1. Values looked up as attributes of a module are assumed to be constant: - -```{eval-rst} -.. testcode:: - - import math - import torch - - @torch.jit.script - def fn(): - return math.pi -``` - -2. Attributes of a ScriptModule can be marked constant by annotating them with `Final[T]` - -``` -import torch -import torch.nn as nn - -class Foo(nn.Module): - # `Final` from the `typing_extensions` module can also be used - a : torch.jit.Final[int] - - def __init__(self): - super().__init__() - self.a = 1 + 4 - - def forward(self, input): - return self.a + input - -f = torch.jit.script(Foo()) -``` - -Supported constant Python types are - -- `int` -- `float` -- `bool` -- `torch.device` -- `torch.layout` -- `torch.dtype` -- tuples containing supported types -- `torch.nn.ModuleList` which can be used in a TorchScript for loop - -(module-attributes)= -(Module Attributes)= - -### Module Attributes - -The `torch.nn.Parameter` wrapper and `register_buffer` can be used to assign -tensors to a module. Other values assigned to a module that is compiled -will be added to the compiled module if their types can be inferred. All [types] -available in TorchScript can be used as module attributes. Tensor attributes are -semantically the same as buffers. The type of empty lists and dictionaries and `None` -values cannot be inferred and must be specified via -[PEP 526-style](https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations) class annotations. -If a type cannot be inferred and is not explicitly annotated, it will not be added as an attribute -to the resulting {class}`ScriptModule`. - -Example: - -```{eval-rst} -.. testcode:: - - from typing import List, Dict - - class Foo(nn.Module): - # `words` is initialized as an empty list, so its type must be specified - words: List[str] - - # The type could potentially be inferred if `a_dict` (below) was not - # empty, but this annotation ensures `some_dict` will be made into the - # proper type - some_dict: Dict[str, int] - - def __init__(self, a_dict): - super().__init__() - self.words = [] - self.some_dict = a_dict - - # `int`s can be inferred - self.my_int = 10 - - def forward(self, input): - # type: (str) -> int - self.words.append(input) - return self.some_dict[input] + self.my_int - - f = torch.jit.script(Foo({'hi': 2})) -``` +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file diff --git a/docs/source/jit_language_reference_v2.md b/docs/source/jit_language_reference_v2.md index 12bd2a18a201..40da0740963b 100644 --- a/docs/source/jit_language_reference_v2.md +++ b/docs/source/jit_language_reference_v2.md @@ -25,1830 +25,7 @@ # TorchScript Language Reference -This reference manual describes the syntax and core semantics of the TorchScript language. -TorchScript is a statically typed subset of the Python language. This document explains the supported features of -Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in -this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to -represent neural network models in PyTorch. - -```{contents} -:depth: 1 -:local: true -``` - -(type-system)= - -## Terminology - -This document uses the following terminologies: - -```{eval-rst} -.. list-table:: - :widths: 25 25 - :header-rows: 1 - - * - Pattern - - Notes - * - ``::=`` - - Indicates that the given symbol is defined as. - * - ``" "`` - - Represents real keywords and delimiters that are part of the syntax. - * - ``A | B`` - - Indicates either A or B. - * - ``( )`` - - Indicates grouping. - * - ``[]`` - - Indicates optional. - * - ``A+`` - - Indicates a regular expression where term A is repeated at least once. - * - ``A*`` - - Indicates a regular expression where term A is repeated zero or more times. -``` - -## Type System - -TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express -neural net models. - -### TorchScript Types - -The TorchScript type system consists of `TSType` and `TSModuleType` as defined below. - -``` -TSAllType ::= TSType | TSModuleType -TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType -``` - -`TSType` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. -`TSType` refers to any of the following: - -- Meta Types, e.g., `Any` -- Primitive Types, e.g., `int`, `float`, and `str` -- Structural Types, e.g., `Optional[int]` or `List[MyClass]` -- Nominal Types (Python classes), e.g., `MyClass` (user-defined), `torch.tensor` (built-in) - -`TSModuleType` represents `torch.nn.Module` and its subclasses. It is treated differently from `TSType` because its type schema is inferred partly from the object instance and partly from the class definition. -As such, instances of a `TSModuleType` may not follow the same static type schema. `TSModuleType` cannot be used as a TorchScript type annotation or be composed with `TSType` for type safety considerations. - -### Meta Types - -Meta types are so abstract that they are more like type constraints than concrete types. -Currently TorchScript defines one meta-type, `Any`, that represents any TorchScript type. - -#### `Any` Type - -The `Any` type represents any TorchScript type. `Any` specifies no type constraints, thus there is no type-checking on `Any`. -As such it can be bound to any Python or TorchScript data types (e.g., `int`, TorchScript `tuple`, or an arbitrary Python class that is not scripted). - -``` -TSMetaType ::= "Any" -``` - -Where: - -- `Any` is the Python class name from the typing module. Therefore, to use the `Any` type, you must import it from `typing` (e.g., `from typing import Any`). -- Since `Any` can represent any TorchScript type, the set of operators that are allowed to operate on values of this type on `Any` is limited. - -#### Operators Supported for `Any` Type - -- Assignment to data of `Any` type. -- Binding to parameter or return of `Any` type. -- `x is`, `x is not` where `x` is of `Any` type. -- `isinstance(x, Type)` where `x` is of `Any` type. -- Data of `Any` type is printable. -- Data of `List[Any]` type may be sortable if the data is a list of values of the same type `T` and that `T` supports comparison operators. - -**Compared to Python** - -`Any` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the -`Object` class in Python. However, `Any` only supports a subset of the operators and methods that are supported by `Object`. - -#### Design Notes - -When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described -by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary -scripting failures. `Any` is introduced to describe the type of the data where precise static types are not necessary for compilation. - -**Example 1** - -This example illustrates how `Any` can be used to allow the second element of the tuple parameter to be of any type. This is possible -because `x[1]` is not involved in any computation that requires knowing its precise type. - -```{eval-rst} -.. testcode:: - - import torch - - from typing import Tuple - from typing import Any - - @torch.jit.export - def inc_first_element(x: Tuple[int, Any]): - return (x[0]+1, x[1]) - - m = torch.jit.script(inc_first_element) - print(m((1,2.0))) - print(m((1,(100,200)))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - (2, 2.0) - (2, (100, 200)) -``` - -The second element of the tuple is of `Any` type, thus can bind to multiple types. -For example, `(1, 2.0)` binds a float type to `Any` as in `Tuple[int, Any]`, -whereas `(1, (100, 200))` binds a tuple to `Any` in the second invocation. - -**Example 2** - -This example illustrates how we can use `isinstance` to dynamically check the type of the data that is annotated as `Any` type: - -```{eval-rst} -.. testcode:: - - import torch - from typing import Any - - def f(a:Any): - print(a) - return (isinstance(a, torch.Tensor)) - - ones = torch.ones([2]) - m = torch.jit.script(f) - print(m(ones)) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - 1 - 1 - [ CPUFloatType{2} ] - True -``` - -### Primitive Types - -Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined -type name. - -``` -TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" -``` - -### Structural Types - -Structural types are types that are structurally defined without a user-defined name (unlike nominal types), -such as `Future[int]`. Structural types are composable with any `TSType`. - -``` -TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | - TSOptional | TSUnion | TSFuture | TSRRef | TSAwait - -TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" -TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" -TSList ::= "List" "[" TSType "]" -TSOptional ::= "Optional" "[" TSType "]" -TSUnion ::= "Union" "[" (TSType ",")* TSType "]" -TSFuture ::= "Future" "[" TSType "]" -TSRRef ::= "RRef" "[" TSType "]" -TSAwait ::= "Await" "[" TSType "]" -TSDict ::= "Dict" "[" KeyType "," TSType "]" -KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" -``` - -Where: - -- `Tuple`, `List`, `Optional`, `Union`, `Future`, `Dict` represent Python type class names that are defined in the module `typing`. To use these type names, you must import them from `typing` (e.g., `from typing import Tuple`). -- `namedtuple` represents the Python class `collections.namedtuple` or `typing.NamedTuple`. -- `Future` and `RRef` represent the Python classes `torch.futures` and `torch.distributed.rpc`. -- `Await` represent the Python class `torch._awaits._Await` - -**Compared to Python** - -Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. - -**Example 1** - -This example uses `typing.NamedTuple` syntax to define a tuple: - -```{eval-rst} -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - - class MyTuple(NamedTuple): - first: int - second: int - - def inc(x: MyTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - t = MyTuple(first=1, second=2) - scripted_inc = torch.jit.script(inc) - print("TorchScript:", scripted_inc(t)) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - TorchScript: (2, 3) -``` - -**Example 2** - -This example uses `collections.namedtuple` syntax to define a tuple: - -```{eval-rst} -.. testcode:: - - import torch - from typing import NamedTuple - from typing import Tuple - from collections import namedtuple - - _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) - _UnannotatedNamedTuple = namedtuple('_NamedTupleAnnotated', ['first', 'second']) - - def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: - return (x.first+1, x.second+1) - - m = torch.jit.script(inc) - print(inc(_UnannotatedNamedTuple(1,2))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - (2, 3) -``` - -**Example 3** - -This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type -classes from the `typing` module: - -```python -import torch - -# ERROR: Tuple not recognized because not imported from typing -@torch.jit.export -def inc(x: Tuple[int, int]): - return (x[0]+1, x[1]+1) - -m = torch.jit.script(inc) -print(m((1,2))) -``` - -Running the above code yields the following scripting error: - -```python -File "test-tuple.py", line 5, in - def inc(x: Tuple[int, int]): -NameError: name 'Tuple' is not defined -``` - -The remedy is to add the line `from typing import Tuple` to the beginning of the code. - -### Nominal Types - -Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom -name and are compared using class names. Nominal classes are further classified into the following categories: - -``` -TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum -``` - -Among them, `TSCustomClass` and `TSEnum` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. - -### Built-in Class - -Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). -TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or -attributes of its Python class definition. - -``` -TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | - "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... -TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | - "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor -``` - -#### Special Note on torch.nn.ModuleList and torch.nn.ModuleDict - -Although `torch.nn.ModuleList` and `torch.nn.ModuleDict` are defined as a list and dictionary in Python, -they behave more like tuples in TorchScript: - -- In TorchScript, instances of `torch.nn.ModuleList` or `torch.nn.ModuleDict` are immutable. -- Code that iterates over `torch.nn.ModuleList` or `torch.nn.ModuleDict` is completely unrolled so that elements of `torch.nn.ModuleList` or keys of `torch.nn.ModuleDict` can be of different subclasses of `torch.nn.Module`. - -**Example** - -The following example highlights the use of a few built-in Torchscript classes (`torch.*`): - -```python -import torch - -@torch.jit.script -class A: - def __init__(self): - self.x = torch.rand(3) - - def f(self, y: torch.device): - return self.x.to(device=y) - -def g(): - a = A() - return a.f(torch.device("cpu")) - -script_g = torch.jit.script(g) -print(script_g.graph) -``` - -### Custom Class - -Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. - -``` -TSClassDef ::= [ "@torch.jit.script" ] - "class" ClassName [ "(object)" ] ":" - MethodDefinition | - [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] - MethodDefinition -``` - -Where: - -- Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from the object. -- Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the `__init__()` method. -- Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). -- `MethodDefinition` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules, (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). -- `torch.jit.ignore` and `torch.jit.unused` can be used to ignore the method or function that is not fully torchscriptable or should be ignored by the compiler. - -**Compared to Python** - -TorchScript custom classes are quite limited compared to their Python counterpart. Torchscript custom classes: - -- Do not support class attributes. -- Do not support subclassing except for subclassing an interface type or object. -- Do not support method overloading. -- Must initialize all its instance attributes in `__init__()`; this is because TorchScript constructs a static schema of the class by inferring attribute types in `__init__()`. -- Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. - -**Example 1** - -Python classes can be used in TorchScript if they are annotated with `@torch.jit.script`, similar to how a TorchScript function would be declared: - -```python -@torch.jit.script -class MyClass: - def __init__(self, x: int): - self.x = x - - def inc(self, val: int): - self.x += val -``` - -**Example 2** - -A TorchScript custom class type must "declare" all its instance attributes by assignments in `__init__()`. If an instance attribute is not defined in `__init__()` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: - -```python -import torch - -@torch.jit.script -class foo: - def __init__(self): - self.y = 1 - -# ERROR: self.x is not defined in __init__ -def assign_x(self): - self.x = torch.rand(2, 3) -``` - -The class will fail to compile and issue the following error: - -``` -RuntimeError: -Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: -def assign_x(self): - self.x = torch.rand(2, 3) - ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE -``` - -**Example 3** - -In this example, a TorchScript custom class defines a class variable name, which is not allowed: - -```python -import torch - -@torch.jit.script -class MyClass(object): - name = "MyClass" - def __init__(self, x: int): - self.x = x - -def fn(a: MyClass): - return a.name -``` - -It leads to the following compile-time error: - -``` -RuntimeError: -'__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: - File "test-class2.py", line 10 -def fn(a: MyClass): - return a.name - ~~~~~~ <--- HERE -``` - -### Enum Type - -Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. - -``` -TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" - ( MemberIdentifier "=" Value )+ - ( MethodDefinition )* -``` - -Where: - -- Value must be a TorchScript literal of type `int`, `float`, or `str`, and must be of the same TorchScript type. -- `TSEnumType` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted `Enum` subclassing, that is, subclassing an enumerated is allowed only if it does not define any members. - -**Compared to Python** - -- TorchScript supports only `enum.Enum`. It does not support other variations such as `enum.IntEnum`, `enum.Flag`, `enum.IntFlag`, and `enum.auto`. -- Values of TorchScript enum members must be of the same type and can only be `int`, `float`, or `str` types, whereas Python enum members can be of any type. -- Enums containing methods are ignored in TorchScript. - -**Example 1** - -The following example defines the class `Color` as an `Enum` type: - -```python -import torch -from enum import Enum - -class Color(Enum): - RED = 1 - GREEN = 2 - -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - -m = torch.jit.script(enum_fn) - -print("Eager: ", enum_fn(Color.RED, Color.GREEN)) -print("TorchScript: ", m(Color.RED, Color.GREEN)) -``` - -**Example 2** - -The following example shows the case of restricted enum subclassing, where `BaseColor` does not define any member, thus can be subclassed by `Color`: - -```python -import torch -from enum import Enum - -class BaseColor(Enum): - def foo(self): - pass - -class Color(BaseColor): - RED = 1 - GREEN = 2 - -def enum_fn(x: Color, y: Color) -> bool: - if x == Color.RED: - return True - return x == y - -m = torch.jit.script(enum_fn) - -print("TorchScript: ", m(Color.RED, Color.GREEN)) -print("Eager: ", enum_fn(Color.RED, Color.GREEN)) -``` - -### TorchScript Module Class - -`TSModuleType` is a special class type that is inferred from object instances that are created outside TorchScript. `TSModuleType` is named by the Python class of the object instance. The `__init__()` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. - -The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from `__init__()` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. - -In this sense, `TSModuleType` is not really a static type. Therefore, for type safety considerations, `TSModuleType` cannot be used in a TorchScript type annotation or be composed with `TSType`. - -### Module Instance Class - -TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as parameter to `forward`). The Python module class is treated as a module instance class, so the `__init__()` method of the Python module class is not subject to the type-checking rules of TorchScript. - -``` -TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" - ClassBodyDefinition -``` - -Where: - -- `forward()` and other methods decorated with `@torch.jit.export` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. - -Unlike custom classes, only the forward method and other methods decorated with `@torch.jit.export` of the module type need to be compilable. Most notably, `__init__()` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into `torch.jit.script(ModuleObj)`. - -**Example 1** - -This example illustrates a few features of module types: - -- The `TestModule` instance is created outside the scope of TorchScript (i.e., before invoking `torch.jit.script`). -- `__init__()` is not considered a TorchScript method, therefore, it does not have to be annotated and can contain arbitrary Python code. In addition, the `__init__()` method of an instance class cannot be invoked in TorchScript code. Because `TestModule` instances are instantiated in Python, in this example, `TestModule(2.0)` and `TestModule(2)` create two instances with different types for its data attributes. `self.x` is of type `float` for `TestModule(2.0)`, whereas `self.y` is of type `int` for `TestModule(2.0)`. -- TorchScript automatically compiles other methods (e.g., `mul()`) invoked by methods annotated via `@torch.jit.export` or `forward()` methods. -- Entry-points to a TorchScript program are either `forward()` of a module type, functions annotated as `torch.jit.script`, or methods annotated as `torch.jit.export`. - -```{eval-rst} -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, inc: int): - return self.x + inc - - m = torch.jit.script(TestModule(1)) - print(f"First instance: {m(3)}") - - m = torch.jit.script(TestModule(torch.ones([5]))) - print(f"Second instance: {m(3)}") -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - First instance: 4 - Second instance: tensor([4., 4., 4., 4., 4.]) -``` - -**Example 2** - -The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of `TestModule` inside the scope of TorchScript: - -```{eval-rst} -.. testcode:: - - import torch - - class TestModule(torch.nn.Module): - def __init__(self, v): - super().__init__() - self.x = v - - def forward(self, x: int): - return self.x + x - - class MyModel: - def __init__(self, v: int): - self.val = v - - @torch.jit.export - def doSomething(self, val: int) -> int: - # error: should not invoke the constructor of module type - myModel = TestModule(self.val) - return myModel(val) - - # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError - # RuntimeError: Could not get name of python class object -``` - -(type-annotation)= - -## Type Annotation - -Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or -instance data attribute has a static type, and every function and method has a statically typed signature. - -### When to Annotate Types - -In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to -methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type -may be too restrictive, e.g., `x` being inferred as `NoneType` through assignment `x = None`, whereas `x` is actually used as an `Optional`. In such -cases, type annotations may be needed to overwrite auto inference, e.g., `x: Optional[int] = None`. Note that it is always safe to type annotate a local variable -or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. - -When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a -default type of `TensorType`, `List[TensorType]`, or `Dict[str, TensorType]`. - -### Annotate Function Signature - -Since a parameter may not be automatically inferred from the body of the function (including both functions and methods), they need to be type annotated. Otherwise, they assume the default type `TensorType`. - -TorchScript supports two styles for method and function signature type annotation: - -- **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of `TensorType`), or allows the return type to be left unannotated (whose type will be automatically inferred). - -``` -Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" - FuncOrMethodBody -ParamAnnot ::= Identifier [ ":" TSType ] "," -ReturnAnnot ::= "->" TSType -``` - -Note that when using Python3 style, the type `self` is automatically inferred and should not be annotated. - -- **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. - -``` -MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] -ParamAnnot ::= TSType "," -ReturnAnnot ::= "->" TSType -``` - -**Example 1** - -In this example: - -- `a` is not annotated and assumes the default type of `TensorType`. -- `b` is annotated as type `int`. -- The return type is not annotated and is automatically inferred as type `TensorType` (based on the type of the value being returned). - -```python -import torch - -def f(a, b: int): - return a+b - -m = torch.jit.script(f) -print("TorchScript:", m(torch.ones([6]), 100)) -``` - -**Example 2** - -The following example uses Mypy style annotation. Note that parameters or return values must be annotated even if some of -them assume the default type. - -```python -import torch - -def f(a, b): - # type: (torch.Tensor, int) → torch.Tensor - return a+b - -m = torch.jit.script(f) -print("TorchScript:", m(torch.ones([6]), 100)) -``` - -### Annotate Variables and Data Attributes - -In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. -Sometimes, however, if a variable or attribute is associated with values of different types (e.g., as `None` or `TensorType`), then they may need to be explicitly -type annotated as a *wider* type such as `Optional[int]` or `Any`. - -#### Local Variables - -Local variables can be annotated according to Python3 typing module annotation rules, i.e., - -``` -LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr -``` - -In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables -that may be associated with different concrete types. Typical multi-types include `Optional[T]` and `Any`. - -**Example** - -```python -import torch - -def f(a, setVal: bool): - value: Optional[torch.Tensor] = None - if setVal: - value = a - return value - -ones = torch.ones([6]) -m = torch.jit.script(f) -print("TorchScript:", m(ones, True), m(ones, False)) -``` - -#### Instance Data Attributes - -For `ModuleType` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final -via `Final`. - -``` -"class" ClassIdentifier "(torch.nn.Module):" -InstanceAttrIdentifier ":" ["Final("] TSType [")"] -... -``` - -Where: - -- `InstanceAttrIdentifier` is the name of an instance attribute. -- `Final` indicates that the attribute cannot be re-assigned outside of `__init__` or overridden in subclasses. - -**Example** - -```python -import torch - -class MyModule(torch.nn.Module): - offset_: int - -def __init__(self, offset): - self.offset_ = offset - -... -``` - -### Type Annotation APIs - -#### `torch.jit.annotate(T, expr)` - -This API annotates type `T` to an expression `expr`. This is often used when the default type of an expression is not the type intended by the programmer. -For instance, an empty list (dictionary) has the default type of `List[TensorType]` (`Dict[TensorType, TensorType]`), but sometimes it may be used to initialize -a list of some other types. Another common use case is for annotating the return type of `tensor.tolist()`. Note, however, that it cannot be used to annotate -the type of a module attribute in `__init__`; `torch.jit.Attribute` should be used for this instead. - -**Example** - -In this example, `[]` is declared as a list of integers via `torch.jit.annotate` (instead of assuming `[]` to be the default type of `List[TensorType]`): - -```python -import torch -from typing import List - -def g(l: List[int], val: int): - l.append(val) - return l - -def f(val: int): - l = g(torch.jit.annotate(List[int], []), val) - return l - -m = torch.jit.script(f) -print("Eager:", f(3)) -print("TorchScript:", m(3)) -``` - -See {meth}`torch.jit.annotate` for more information. - -### Type Annotation Appendix - -#### TorchScript Type System Definition - -``` -TSAllType ::= TSType | TSModuleType -TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType - -TSMetaType ::= "Any" -TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" - -TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | - TSUnion | TSFuture | TSRRef | TSAwait -TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" -TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" -TSList ::= "List" "[" TSType "]" -TSOptional ::= "Optional" "[" TSType "]" -TSUnion ::= "Union" "[" (TSType ",")* TSType "]" -TSFuture ::= "Future" "[" TSType "]" -TSRRef ::= "RRef" "[" TSType "]" -TSAwait ::= "Await" "[" TSType "]" -TSDict ::= "Dict" "[" KeyType "," TSType "]" -KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" - -TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum -TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| - "torch.dtype" | "torch.nn.ModuleList" | - "torch.nn.ModuleDict" | ... -TSTensor ::= "torch.tensor" and subclasses -``` - -#### Unsupported Typing Constructs - -TorchScript does not support all features and types of the Python3 [typing](https://docs.python.org/3/library/typing.html#module-typing) module. -Any functionality from the [typing](https://docs.python.org/3/library/typing.html#module-typing) module that is not explicitly specified in this -documentation is unsupported. The following table summarizes `typing` constructs that are either unsupported or supported with restrictions in TorchScript. - -```{eval-rst} -============================= ================ - Item Description ------------------------------ ---------------- -``typing.Any`` In development -``typing.NoReturn`` Not supported -``typing.Callable`` Not supported -``typing.Literal`` Not supported -``typing.ClassVar`` Not supported -``typing.Final`` Supported for module attributes, class attribute, and annotations, but not for functions. -``typing.AnyStr`` Not supported -``typing.overload`` In development -Type aliases Not supported -Nominal typing In development -Structural typing Not supported -NewType Not supported -Generics Not supported -============================= ================ -``` - -(expressions)= - -## Expressions - -The following section describes the grammar of expressions that are supported in TorchScript. -It is modeled after [the expressions chapter of the Python language reference](https://docs.python.org/3/reference/expressions.html). - -### Arithmetic Conversions - -There are a number of implicit type conversions that are performed in TorchScript: - -- A `Tensor` with a `float` or `int` data type can be implicitly converted to an instance of `FloatType` or `IntType` provided that it has a size of 0, does not have `require_grad` set to `True`, and will not require narrowing. -- Instances of `StringType` can be implicitly converted to `DeviceType`. -- The implicit conversion rules from the two bullet points above can be applied to instances of `TupleType` to produce instances of `ListType` with the appropriate contained type. - -Explicit conversions can be invoked using the `float`, `int`, `bool`, and `str` built-in functions -that accept primitive data types as arguments and can accept user-defined types if they implement -`__bool__`, `__str__`, etc. - -### Atoms - -Atoms are the most basic elements of expressions. - -``` -atom ::= identifier | literal | enclosure -enclosure ::= parenth_form | list_display | dict_display -``` - -#### Identifiers - -The rules that dictate what is a legal identifier in TorchScript are the same as -their [Python counterparts](https://docs.python.org/3/reference/lexical_analysis.html#identifiers). - -#### Literals - -``` -literal ::= stringliteral | integer | floatnumber -``` - -Evaluation of a literal yields an object of the appropriate type with the specific value -(with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations -of identical literals may obtain the same object or distinct objects with the same value. -[stringliteral](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals), -[integer](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals), and -[floatnumber](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) -are defined in the same way as their Python counterparts. - -#### Parenthesized Forms - -``` -parenth_form ::= '(' [expression_list] ')' -``` - -A parenthesized expression list yields whatever the expression list yields. If the list contains at least one -comma, it yields a `Tuple`; otherwise, it yields the single expression inside the expression list. An empty -pair of parentheses yields an empty `Tuple` object (`Tuple[]`). - -#### List and Dictionary Displays - -``` -list_comprehension ::= expression comp_for -comp_for ::= 'for' target_list 'in' or_expr -list_display ::= '[' [expression_list | list_comprehension] ']' -dict_display ::= '{' [key_datum_list | dict_comprehension] '}' -key_datum_list ::= key_datum (',' key_datum)* -key_datum ::= expression ':' expression -dict_comprehension ::= key_datum comp_for -``` - -Lists and dicts can be constructed by either listing the container contents explicitly or by providing -instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension -is semantically equivalent to using a for loop and appending to an ongoing list. -Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the -enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list -are evaluated left-to-right. If a key is repeated in a `dict_display` that has a `key_datum_list`, the -resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. - -### Primaries - -``` -primary ::= atom | attributeref | subscription | slicing | call -``` - -#### Attribute References - -``` -attributeref ::= primary '.' identifier -``` - -The `primary` must evaluate to an object of a type that supports attribute references that have an attribute named -`identifier`. - -#### Subscriptions - -``` -subscription ::= primary '[' expression_list ']' -``` - -The `primary` must evaluate to an object that supports subscription. - -- If the primary is a `List`, `Tuple`, or `str`, the expression list must evaluate to an integer or slice. -- If the primary is a `Dict`, the expression list must evaluate to an object of the same type as the key type of the `Dict`. -- If the primary is a `ModuleList`, the expression list must be an `integer` literal. -- If the primary is a `ModuleDict`, the expression must be a `stringliteral`. - -#### Slicings - -A slicing selects a range of items in a `str`, `Tuple`, `List`, or `Tensor`. Slicings may be used as -expressions or targets in assignment or `del` statements. - -``` -slicing ::= primary '[' slice_list ']' -slice_list ::= slice_item (',' slice_item)* [','] -slice_item ::= expression | proper_slice -proper_slice ::= [expression] ':' [expression] [':' [expression] ] -``` - -Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an -object of type `Tensor`. - -#### Calls - -``` -call ::= primary '(' argument_list ')' -argument_list ::= args [',' kwargs] | kwargs -args ::= [arg (',' arg)*] -kwargs ::= [kwarg (',' kwarg)*] -kwarg ::= arg '=' expression -arg ::= identifier -``` - -The `primary` must desugar or evaluate to a callable object. All argument expressions are evaluated -before the call is attempted. - -### Power Operator - -``` -power ::= primary ['**' u_expr] -``` - -The power operator has the same semantics as the built-in pow function (not supported); it computes its -left argument raised to the power of its right argument. It binds more tightly than unary operators on the -left, but less tightly than unary operators on the right; i.e. `-2 ** -3 == -(2 ** (-3))`. The left and right -operands can be `int`, `float` or `Tensor`. Scalars are broadcast in the case of scalar-tensor/tensor-scalar -exponentiation operations, and tensor-tensor exponentiation is done elementwise without any broadcasting. - -### Unary and Arithmetic Bitwise Operations - -``` -u_expr ::= power | '-' power | '~' power -``` - -The unary `-` operator yields the negation of its argument. The unary `~` operator yields the bitwise inversion -of its argument. `-` can be used with `int`, `float`, and `Tensor` of `int` and `float`. -`~` can only be used with `int` and `Tensor` of `int`. - -### Binary Arithmetic Operations - -``` -m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr -a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr -``` - -The binary arithmetic operators can operate on `Tensor`, `int`, and `float`. For tensor-tensor ops, both arguments must -have the same shape. For scalar-tensor or tensor-scalar ops, the scalar is usually broadcast to the size of the -tensor. Division ops can only accept scalars as their right-hand side argument, and do not support broadcasting. -The `@` operator is for matrix multiplication and only operates on `Tensor` arguments. The multiplication operator -(`*`) can be used with a list and integer in order to get a result that is the original list repeated a certain -number of times. - -### Shifting Operations - -``` -shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr -``` - -These operators accept two `int` arguments, two `Tensor` arguments, or a `Tensor` argument and an `int` or -`float` argument. In all cases, a right shift by `n` is defined as floor division by `pow(2, n)`, and a left shift -by `n` is defined as multiplication by `pow(2, n)`. When both arguments are `Tensors`, they must have the same -shape. When one is a scalar and the other is a `Tensor`, the scalar is logically broadcast to match the size of -the `Tensor`. - -### Binary Bitwise Operations - -``` -and_expr ::= shift_expr | and_expr '&' shift_expr -xor_expr ::= and_expr | xor_expr '^' and_expr -or_expr ::= xor_expr | or_expr '|' xor_expr -``` - -The `&` operator computes the bitwise AND of its arguments, the `^` the bitwise XOR, and the `|` the bitwise OR. -Both operands must be `int` or `Tensor`, or the left operand must be `Tensor` and the right operand must be -`int`. When both operands are `Tensor`, they must have the same shape. When the right operand is `int`, and -the left operand is `Tensor`, the right operand is logically broadcast to match the shape of the `Tensor`. - -### Comparisons - -``` -comparison ::= or_expr (comp_operator or_expr)* -comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' -``` - -A comparison yields a boolean value (`True` or `False`), or if one of the operands is a `Tensor`, a boolean -`Tensor`. Comparisons can be chained arbitrarily as long as they do not yield boolean `Tensors` that have more -than one element. `a op1 b op2 c ...` is equivalent to `a op1 b and b op2 c and ...`. - -#### Value Comparisons - -The operators `<`, `>`, `==`, `>=`, `<=`, and `!=` compare the values of two objects. The two objects generally need to be of -the same type, unless there is an implicit type conversion available between the objects. User-defined types can -be compared if rich comparison methods (e.g., `__lt__`) are defined on them. Built-in type comparison works like -Python: - -- Numbers are compared mathematically. -- Strings are compared lexicographically. -- `lists`, `tuples`, and `dicts` can be compared only to other `lists`, `tuples`, and `dicts` of the same type and are compared using the comparison operator of corresponding elements. - -#### Membership Test Operations - -The operators `in` and `not in` test for membership. `x in s` evaluates to `True` if `x` is a member of `s` and `False` otherwise. -`x not in s` is equivalent to `not x in s`. This operator is supported for `lists`, `dicts`, and `tuples`, and can be used with -user-defined types if they implement the `__contains__` method. - -#### Identity Comparisons - -For all types except `int`, `double`, `bool`, and `torch.device`, operators `is` and `is not` test for the object’s identity; -`x is y` is `True` if and only if `x` and `y` are the same object. For all other types, `is` is equivalent to -comparing them using `==`. `x is not y` yields the inverse of `x is y`. - -### Boolean Operations - -``` -or_test ::= and_test | or_test 'or' and_test -and_test ::= not_test | and_test 'and' not_test -not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test -``` - -User-defined objects can customize their conversion to `bool` by implementing a `__bool__` method. The operator `not` -yields `True` if its operand is false, `False` otherwise. The expression `x` and `y` first evaluates `x`; if it is `False`, its -value (`False`) is returned; otherwise, `y` is evaluated and its value is returned (`False` or `True`). The expression `x` or `y` -first evaluates `x`; if it is `True`, its value (`True`) is returned; otherwise, `y` is evaluated and its value is returned -(`False` or `True`). - -### Conditional Expressions - -``` -conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] -expression ::= conditional_expression -``` - -The expression `x if c else y` first evaluates the condition `c` rather than x. If `c` is `True`, `x` is -evaluated and its value is returned; otherwise, `y` is evaluated and its value is returned. As with if-statements, -`x` and `y` must evaluate to a value of the same type. - -### Expression Lists - -``` -expression_list ::= expression (',' expression)* [','] -starred_item ::= '*' primary -``` - -A starred item can only appear on the left-hand side of an assignment statement, e.g., `a, *b, c = ...`. - -% statements: - -## Simple Statements - -The following section describes the syntax of simple statements that are supported in TorchScript. -It is modeled after [the simple statements chapter of the Python language reference](https://docs.python.org/3/reference/simple_stmts.html). - -### Expression Statements - -``` -expression_stmt ::= starred_expression -starred_expression ::= expression | (starred_item ",")* [starred_item] -starred_item ::= assignment_expression | "*" or_expr -``` - -### Assignment Statements - -``` -assignment_stmt ::= (target_list "=")+ (starred_expression) -target_list ::= target ("," target)* [","] -target ::= identifier - | "(" [target_list] ")" - | "[" [target_list] "]" - | attributeref - | subscription - | slicing - | "*" target -``` - -### Augmented Assignment Statements - -``` -augmented_assignment_stmt ::= augtarget augop (expression_list) -augtarget ::= identifier | attributeref | subscription -augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | - "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" -``` - -### Annotated Assignment Statements - -``` -annotated_assignment_stmt ::= augtarget ":" expression - ["=" (starred_expression)] -``` - -### The `raise` Statement - -``` -raise_stmt ::= "raise" [expression ["from" expression]] -``` - -Raise statements in TorchScript do not support `try\except\finally`. - -### The `assert` Statement - -``` -assert_stmt ::= "assert" expression ["," expression] -``` - -Assert statements in TorchScript do not support `try\except\finally`. - -### The `return` Statement - -``` -return_stmt ::= "return" [expression_list] -``` - -Return statements in TorchScript do not support `try\except\finally`. - -### The `del` Statement - -``` -del_stmt ::= "del" target_list -``` - -### The `pass` Statement - -``` -pass_stmt ::= "pass" -``` - -### The `print` Statement - -``` -print_stmt ::= "print" "(" expression [, expression] [.format{expression_list}] ")" -``` - -### The `break` Statement - -``` -break_stmt ::= "break" -``` - -### The `continue` Statement: - -``` -continue_stmt ::= "continue" -``` - -## Compound Statements - -The following section describes the syntax of compound statements that are supported in TorchScript. -The section also highlights how Torchscript differs from regular Python statements. -It is modeled after [the compound statements chapter of the Python language reference](https://docs.python.org/3/reference/compound_stmts.html). - -### The `if` Statement - -Torchscript supports both basic `if/else` and ternary `if/else`. - -#### Basic `if/else` Statement - -``` -if_stmt ::= "if" assignment_expression ":" suite - ("elif" assignment_expression ":" suite) - ["else" ":" suite] -``` - -`elif` statements can repeat for an arbitrary number of times, but it needs to be before `else` statement. - -#### Ternary `if/else` Statement - -``` -if_stmt ::= return [expression_list] "if" assignment_expression "else" [expression_list] -``` - -**Example 1** - -A `tensor` with 1 dimension is promoted to `bool`: - -```{eval-rst} -.. testcode:: - - import torch - - @torch.jit.script - def fn(x: torch.Tensor): - if x: # The tensor gets promoted to bool - return True - return False - print(fn(torch.rand(1))) -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - True -``` - -**Example 2** - -A `tensor` with multi dimensions are not promoted to `bool`: - -```python -import torch - -# Multi dimensional Tensors error out. - -@torch.jit.script -def fn(): - if torch.rand(2): - print("Tensor is available") - - if torch.rand(4,5,6): - print("Tensor is available") - -print(fn()) -``` - -Running the above code yields the following `RuntimeError`. - -``` -RuntimeError: The following operation failed in the TorchScript interpreter. -Traceback of TorchScript (most recent call last): -@torch.jit.script -def fn(): - if torch.rand(2): - ~~~~~~~~~~~~ <--- HERE - print("Tensor is available") -RuntimeError: Boolean value of Tensor with more than one value is ambiguous -``` - -If a conditional variable is annotated as `final`, either the true or false branch is evaluated depending on the evaluation of the conditional variable. - -**Example 3** - -In this example, only the True branch is evaluated, since `a` is annotated as `final` and set to `True`: - -```python -import torch - -a : torch.jit.final[Bool] = True - -if a: - return torch.empty(2,3) -else: - return [] -``` - -### The `while` Statement - -``` -while_stmt ::= "while" assignment_expression ":" suite -``` - -`while...else` statements are not supported in Torchscript. It results in a `RuntimeError`. - -### The `for-in` Statement - -``` -for_stmt ::= "for" target_list "in" expression_list ":" suite - ["else" ":" suite] -``` - -`for...else` statements are not supported in Torchscript. It results in a `RuntimeError`. - -**Example 1** - -For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. - -```{eval-rst} -.. testcode:: - - import torch - from typing import Tuple - - @torch.jit.script - def fn(): - tup = (3, torch.ones(4)) - for x in tup: - print(x) - - fn() -``` - -The example above produces the following output: - -```{eval-rst} -.. testoutput:: - - 3 - 1 - 1 - 1 - 1 - [ CPUFloatType{4} ] - -``` - -**Example 2** - -For loops on lists: for loops over a `nn.ModuleList` will unroll the body of the loop at compile time, with each member of the module list. - -```python -class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.weight = nn.Parameter(torch.randn(2)) - - def forward(self, input): - return self.weight + input - -class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) - - def forward(self, v): - for module in self.mods: - v = module(v) - return v - -model = torch.jit.script(MyModule()) -``` - -### The `with` Statement - -The `with` statement is used to wrap the execution of a block with methods defined by a context manager. - -``` -with_stmt ::= "with" with_item ("," with_item) ":" suite -with_item ::= expression ["as" target] -``` - -- If a target was included in the `with` statement, the return value from the context manager’s `__enter__()` is assigned to it. Unlike python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to `__exit__()`. Three `None` arguments are supplied. -- `try`, `except`, and `finally` statements are not supported inside `with` blocks. -- Exceptions raised within `with` block cannot be suppressed. - -### The `tuple` Statement - -``` -tuple_stmt ::= tuple([iterables]) -``` - -- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. -- You cannot convert a List to Tuple by using this built-in function. - -Unpacking all outputs into a tuple is covered by: - -``` -abc = func() # Function that returns a tuple -a,b = func() -``` - -### The `getattr` Statement - -``` -getattr_stmt ::= getattr(object, name[, default]) -``` - -- Attribute name must be a literal string. -- Module type object is not supported (e.g., torch.\_C). -- Custom class object is not supported (e.g., torch.classes.\*). - -### The `hasattr` Statement - -``` -hasattr_stmt ::= hasattr(object, name) -``` - -- Attribute name must be a literal string. -- Module type object is not supported (e.g., torch.\_C). -- Custom class object is not supported (e.g., torch.classes.\*). - -### The `zip` Statement - -``` -zip_stmt ::= zip(iterable1, iterable2) -``` - -- Arguments must be iterables. -- Two iterables of same outer container type but different length are supported. - -**Example 1** - -Both the iterables must be of the same container type: - -```{eval-rst} -.. testcode:: - - a = [1, 2] # List - b = [2, 3, 4] # List - zip(a, b) # works -``` - -**Example 2** - -This example fails because the iterables are of different container types: - -``` -a = (1, 2) # Tuple -b = [2, 3, 4] # List -zip(a, b) # Runtime error -``` - -Running the above code yields the following `RuntimeError`. - -``` -RuntimeError: Can not iterate over a module list or - tuple with a value that does not have a statically determinable length. -``` - -**Example 3** - -Two iterables of the same container Type but different data type is supported: - -```{eval-rst} -.. testcode:: - - a = [1.3, 2.4] - b = [2, 3, 4] - zip(a, b) # Works -``` - -Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList`, and `torch.nn.ModuleDict`. - -### The `enumerate` Statement - -``` -enumerate_stmt ::= enumerate([iterable]) -``` - -- Arguments must be iterables. -- Iterable types in TorchScript include `Tensors`, `lists`, `tuples`, `dictionaries`, `strings`, `torch.nn.ModuleList` and `torch.nn.ModuleDict`. - -(python-values-torch-script)= - -## Python Values - -(python-builtin-functions-values-resolution)= - -### Resolution Rules - -When given a Python value, TorchScript attempts to resolve it in the following five different ways: - -- Compilable Python Implementation: - : - When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. - - Example: `torch.jit.Attribute` -- Op Python Wrapper: - : - When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. - - Example: `torch.jit._logging.add_stat_value` -- Python Object Identity Match: - : - For a limited set of `torch.*` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. - - When matched, TorchScript generates a corresponding `SugaredValue` instance that contains lowering logic for these values. - - Example: `torch.jit.isinstance()` -- Name Match: - : - For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding `SugaredValue` instance that implements their functionality. - - Example: `all()` -- Value Snapshot: - : - For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. - - Example: `math.pi` - -(python-builtin-functions-support)= - -### Python Built-in Functions Support - -```{eval-rst} -.. list-table:: TorchScript Support for Python Built-in Functions - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Function - - Support Level - - Notes - * - ``abs()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. - * - ``all()`` - - Full - - - * - ``any()`` - - Full - - - * - ``ascii()`` - - None - - - * - ``bin()`` - - Partial - - Only supports ``Int`` type input. - * - ``bool()`` - - Partial - - Only supports ``Tensor``/``Int``/``Float`` type inputs. - * - ``breakpoint()`` - - None - - - * - ``bytearray()`` - - None - - - * - ``bytes()`` - - None - - - * - ``callable()`` - - None - - - * - ``chr()`` - - Partial - - Only ASCII character set is supported. - * - ``classmethod()`` - - Full - - - * - ``compile()`` - - None - - - * - ``complex()`` - - None - - - * - ``delattr()`` - - None - - - * - ``dict()`` - - Full - - - * - ``dir()`` - - None - - - * - ``divmod()`` - - Full - - - * - ``enumerate()`` - - Full - - - * - ``eval()`` - - None - - - * - ``exec()`` - - None - - - * - ``filter()`` - - None - - - * - ``float()`` - - Partial - - Doesn't honor ``__index__`` override. - * - ``format()`` - - Partial - - Manual index specification not supported. | Format type modifier not supported. - * - ``frozenset()`` - - None - - - * - ``getattr()`` - - Partial - - Attribute name must be string literal. - * - ``globals()`` - - None - - - * - ``hasattr()`` - - Partial - - Attribute name must be string literal. - * - ``hash()`` - - Full - - ``Tensor``'s hash is based on identity not numeric value. - * - ``hex()`` - - Partial - - Only supports ``Int`` type input. - * - ``id()`` - - Full - - Only supports ``Int`` type input. - * - ``input()`` - - None - - - * - ``int()`` - - Partial - - ``base`` argument not supported. | Doesn't honor ``__index__`` override. - * - ``isinstance()`` - - Full - - ``torch.jit.isintance`` provides better support when checking against container types like ``Dict[str, int]``. - * - ``issubclass()`` - - None - - - * - ``iter()`` - - None - - - * - ``len()`` - - Full - - - * - ``list()`` - - Full - - - * - ``ord()`` - - Partial - - Only ASCII character set is supported. - * - ``pow()`` - - Full - - - * - ``print()`` - - Partial - - ``separate``, ``end`` and ``file`` arguments are not supported. - * - ``property()`` - - None - - - * - ``range()`` - - Full - - - * - ``repr()`` - - None - - - * - ``reversed()`` - - None - - - * - ``round()`` - - Partial - - ``ndigits`` argument is not supported. - * - ``set()`` - - None - - - * - ``setattr()`` - - None - - - * - ``slice()`` - - Full - - - * - ``sorted()`` - - Partial - - ``key`` argument is not supported. - * - ``staticmethod()`` - - Full - - - * - ``str()`` - - Partial - - ``encoding`` and ``errors`` arguments are not supported. - * - ``sum()`` - - Full - - - * - ``super()`` - - Partial - - It can only be used in ``nn.Module``'s ``__init__`` method. - * - ``type()`` - - None - - - * - ``vars()`` - - None - - - * - ``zip()`` - - Full - - - * - ``__import__()`` - - None - - -``` - -(python-builtin-values-support)= - -### Python Built-in Values Support - -```{eval-rst} -.. list-table:: TorchScript Support for Python Built-in Values - :widths: 25 25 50 - :header-rows: 1 - - * - Built-in Value - - Support Level - - Notes - * - ``False`` - - Full - - - * - ``True`` - - Full - - - * - ``None`` - - Full - - - * - ``NotImplemented`` - - None - - - * - ``Ellipsis`` - - Full - - - -``` - -(torch-apis-in-torchscript)= - -## torch.\* APIs - -(torch-apis-in-torchscript-rpc)= - -### Remote Procedure Calls - -TorchScript supports a subset of RPC APIs that supports running a function on -a specified remote worker instead of locally. - -Specifically, following APIs are fully supported: - -- `torch.distributed.rpc.rpc_sync()` - : - `rpc_sync()` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_sync`. -- `torch.distributed.rpc.rpc_async()` - : - `rpc_async()` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.rpc_async`. -- `torch.distributed.rpc.remote()` - : - `remote.()` executes a remote call on a worker and gets a Remote Reference `RRef` as the return value. - - More details about its usage and examples can be found in {meth}`~torch.distributed.rpc.remote`. - -(torch-apis-in-torchscript-async)= - -### Asynchronous Execution - -TorchScript enables you to create asynchronous computation tasks to make better use -of computation resources. This is done via supporting a list of APIs that are -only usable within TorchScript: - -- `torch.jit.fork()` - : - Creates an asynchronous task executing func and a reference to the value of the result of this execution. Fork will return immediately. - - Synonymous to `torch.jit._fork()`, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in {meth}`~torch.jit.fork`. -- `torch.jit.wait()` - : - Forces completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task. - - Synonymous to `torch.jit._wait()`, which is only kept for backward compatibility reasons. - - More details about its usage and examples can be found in {meth}`~torch.jit.wait`. - -(torch-apis-in-torchscript-annotation)= - -### Type Annotations - -TorchScript is statically-typed. It provides and supports a set of utilities to help annotate variables and attributes: - -- `torch.jit.annotate()` - : - Provides a type hint to TorchScript where Python 3 style type hints do not work well. - - One common example is to annotate type for expressions like `[]`. `[]` is treated as `List[torch.Tensor]` by default. When a different type is needed, you can use this code to hint TorchScript: `torch.jit.annotate(List[int], [])`. - - More details can be found in {meth}`~torch.jit.annotate` -- `torch.jit.Attribute` - : - Common use cases include providing type hint for `torch.nn.Module` attributes. Because their `__init__` methods are not parsed by TorchScript, `torch.jit.Attribute` should be used instead of `torch.jit.annotate` in the module's `__init__` methods. - - More details can be found in {meth}`~torch.jit.Attribute` -- `torch.jit.Final` - : - An alias for Python's `typing.Final`. `torch.jit.Final` is kept only for backward compatibility reasons. - -(torch-apis-in-torchscript-meta-programming)= - -### Meta Programming - -TorchScript provides a set of utilities to facilitate meta programming: - -- `torch.jit.is_scripting()` - : - Returns a boolean value indicating whether the current program is compiled by `torch.jit.script` or not. - - When used in an `assert` or an `if` statement, the scope or branch where `torch.jit.is_scripting()` evaluates to `False` is not compiled. - - Its value can be evaluated statically at compile time, thus commonly used in `if` statements to stop TorchScript from compiling one of the branches. - - More details and examples can be found in {meth}`~torch.jit.is_scripting` -- `torch.jit.is_tracing()` - : - Returns a boolean value indicating whether the current program is traced by `torch.jit.trace` / `torch.jit.trace_module` or not. - - More details can be found in {meth}`~torch.jit.is_tracing` -- `@torch.jit.ignore` - : - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. - - This allows you to leave code in your model that is not yet TorchScript compatible. - - If a function decorated by `@torch.jit.ignore` is called from TorchScript, ignored functions will dispatch the call to the Python interpreter. - - Models with ignored functions cannot be exported. - - More details and examples can be found in {meth}`~torch.jit.ignore` -- `@torch.jit.unused` - : - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. - - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. - - If a function decorated by `@torch.jit.unused` is called from TorchScript, a runtime error will be raised. - - More details and examples can be found in {meth}`~torch.jit.unused` - -(torch-apis-in-torchscript-type-refinement)= - -### Type Refinement - -- `torch.jit.isinstance()` - : - Returns a boolean indicating whether a variable is of the specified type. - - More details about its usage and examples can be found in {meth}`~torch.jit.isinstance`. \ No newline at end of file +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file diff --git a/docs/source/jit_python_reference.md b/docs/source/jit_python_reference.md index 1d2b5c78a894..edcd93bb7b0d 100644 --- a/docs/source/jit_python_reference.md +++ b/docs/source/jit_python_reference.md @@ -2,431 +2,7 @@ # Python Language Reference Coverage -This is a 1:1 mapping of the features listed in https://docs.python.org/3/reference/ and their -support in TorchScript. The categorizations are as follows: - -```{list-table} -:widths: 40 40 20 -:header-rows: 1 - -* - Section - - Status - - Note -* - [1. Introduction](https://docs.python.org/3/reference/introduction.html) - - Not Relevant - - -* - [1.1. Alternate Implementations](https://docs.python.org/3/reference/introduction.html#alternate-implementations) - - Not Relevant - - -* - [1.2. Notation](https://docs.python.org/3/reference/introduction.html#notation) - - Not Relevant - - -* - [2. Lexical analysis](https://docs.python.org/3/reference/lexical_analysis.html#) - - Not Relevant - - -* - [2.1. Line structure](https://docs.python.org/3/reference/lexical_analysis.html#line-structure) - - Not Relevant - - -* - [2.1.1. Logical lines](https://docs.python.org/3/reference/lexical_analysis.html#logical-lines) - - Not Relevant - - -* - [2.1.2. Physical lines](https://docs.python.org/3/reference/lexical_analysis.html#physical-lines) - - Supported - - -* - [2.1.3. Comments](https://docs.python.org/3/reference/lexical_analysis.html#comments) - - Supported - - -* - [2.1.4. Encoding declarations](https://docs.python.org/3/reference/lexical_analysis.html#encoding-declarations) - - Not Supported - - TorchScript explicitly don't support unicode -* - [2.1.5. Explicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#explicit-line-joining) - - Supported - - -* - [2.1.6. Implicit line joining](https://docs.python.org/3/reference/lexical_analysis.html#implicit-line-joining) - - Supported - - -* - [2.1.7. Blank lines](https://docs.python.org/3/reference/lexical_analysis.html#blank-lines) - - Supported - - -* - [2.1.8. Indentation](https://docs.python.org/3/reference/lexical_analysis.html#indentation) - - Supported - - -* - [2.1.9. Whitespace between tokens](https://docs.python.org/3/reference/lexical_analysis.html#whitespace-between-tokens) - - Not Relevant - - -* - [2.2. Other tokens](https://docs.python.org/3/reference/lexical_analysis.html#other-tokens) - - Not Relevant - - -* - [2.3. Identifiers and keywords](https://docs.python.org/3/reference/lexical_analysis.html#identifiers) - - Supported - - -* - [2.3.1. Keywords](https://docs.python.org/3/reference/lexical_analysis.html#keywords) - - Supported - - -* - [2.3.2. Reserved classes of identifiers](https://docs.python.org/3/reference/lexical_analysis.html#reserved-classes-of-identifiers) - - Supported - - -* - [2.4. Literals](https://docs.python.org/3/reference/lexical_analysis.html#literals) - - Not Relevant - - -* - [2.4.1. String and Bytes literals](https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals) - - Supported - - -* - [2.4.2. String literal concatenation](https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation) - - Supported - - -* - [2.4.3. Formatted string literals](https://docs.python.org/3/reference/lexical_analysis.html#formatted-string-literals) - - Partially Supported - - -* - [2.4.4. Numeric literals](https://docs.python.org/3/reference/lexical_analysis.html#numeric-literals) - - Supported - - -* - [2.4.5. Integer literals](https://docs.python.org/3/reference/lexical_analysis.html#integer-literals) - - Supported - - -* - [2.4.6. Floating point literals](https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals) - - Supported - - -* - [2.4.7. Imaginary literals](https://docs.python.org/3/reference/lexical_analysis.html#imaginary-literals) - - Not Supported - - -* - [2.5. Operators](https://docs.python.org/3/reference/lexical_analysis.html#operators) - - Partially Supported - - Not supported: ``<<``, ``>>``, ``:=`` -* - [2.6. Delimiters](https://docs.python.org/3/reference/lexical_analysis.html#delimiters) - - Partially Supported - - Not supported: ``**=``, ``<<=``, ``>>=``, ``%=``, ``^=``, ``@=``, ``&=``, ``//=``, ``%`` operator for some types (e.g. ``str``\ ) -* - [3. Data model](https://docs.python.org/3/reference/datamodel.html#) - - Not Relevant - - -* - [3.1. Objects, values and types](https://docs.python.org/3/reference/datamodel.html#objects-values-and-types) - - Not Relevant - - -* - [3.2. The standard type hierarchy](https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy) - - Partially Supported - - Not supported: NotImplemented, Ellipsis, numbers.Complex, bytes, byte arrays, sets, frozen sets, generators, coroutines, async generators, modules, I/O objects, internal objects, slice objects ( though slicing is supported), classmethod -* - [3.3. Special method names](https://docs.python.org/3/reference/datamodel.html#special-method-names) - - Supported - - -* - [3.3.1. Basic customization](https://docs.python.org/3/reference/datamodel.html#basic-customization) - - Partially Supported - - Not supported: ``__new__`` , ``__del__`` , ``__bytes__`` , ``__format__`` , ``__hash__`` , -* - [3.3.2. Customizing attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-attribute-access) - - Not Supported - - -* - [3.3.2.1. Customizing module attribute access](https://docs.python.org/3/reference/datamodel.html#customizing-module-attribute-access) - - Not Supported - - -* - [3.3.2.2. Implementing Descriptors](https://docs.python.org/3/reference/datamodel.html#implementing-descriptors) - - Not Supported - - -* - [3.3.2.3. Invoking Descriptors](https://docs.python.org/3/reference/datamodel.html#invoking-descriptors) - - Not Supported - - -* - [3.3.2.4. __slots__](https://docs.python.org/3/reference/datamodel.html#slots) - - Not Supported - - -* - [3.3.2.4.1. Notes on using __slots__](https://docs.python.org/3/reference/datamodel.html#notes-on-using-slots) - - Not Supported - - -* - [3.3.3. Customizing class creation](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation) - - Not Supported - - -* - [3.3.3.1. Metaclasses](https://docs.python.org/3/reference/datamodel.html#metaclasses) - - Not Supported - - -* - [3.3.3.2. Resolving MRO entries](https://docs.python.org/3/reference/datamodel.html#resolving-mro-entries) - - Not Supported - - [`super()`` is not supported -* - [3.3.3.3. Determining the appropriate metaclass](https://docs.python.org/3/reference/datamodel.html#determining-the-appropriate-metaclass) - - Not relevant - - -* - [3.3.3.4. Preparing the class namespace](https://docs.python.org/3/reference/datamodel.html#preparing-the-class-namespace) - - Not relevant - - -* - [3.3.3.5. Executing the class body](https://docs.python.org/3/reference/datamodel.html#executing-the-class-body) - - Not relevant - - -* - [3.3.3.6. Creating the class object](https://docs.python.org/3/reference/datamodel.html#creating-the-class-object) - - Not relevant - - -* - [3.3.3.7. Uses for metaclasses](https://docs.python.org/3/reference/datamodel.html#uses-for-metaclasses) - - Not relevant - - -* - [3.3.4. Customizing instance and subclass checks](https://docs.python.org/3/reference/datamodel.html#customizing-instance-and-subclass-checks) - - Not Supported - - -* - [3.3.5. Emulating generic types](https://docs.python.org/3/reference/datamodel.html#emulating-generic-types) - - Not Supported - - -* - [3.3.6. Emulating callable objects](https://docs.python.org/3/reference/datamodel.html#emulating-callable-objects) - - Supported - - -* - [3.3.7. Emulating container types](https://docs.python.org/3/reference/datamodel.html#emulating-container-types) - - Partially Supported - - Some magic methods not supported (e.g. ``__iter__`` ) -* - [3.3.8. Emulating numeric types](https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types) - - Partially Supported - - Magic methods with swapped operands not supported (``__r*__``) -* - [3.3.9. With Statement Context Managers](https://docs.python.org/3/reference/datamodel.html#with-statement-context-managers) - - Not Supported - - -* - [3.3.10. Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-method-lookup) - - Not relevant - - -* - [3.4. Coroutines](https://docs.python.org/3/reference/datamodel.html#coroutines) - - Not Supported - - -* - [3.4.1. Awaitable Objects](https://docs.python.org/3/reference/datamodel.html#awaitable-objects) - - Not Supported - - -* - [3.4.2. Coroutine Objects](https://docs.python.org/3/reference/datamodel.html#coroutine-objects) - - Not Supported - - -* - [3.4.3. Asynchronous Iterators](https://docs.python.org/3/reference/datamodel.html#asynchronous-iterators) - - Not Supported - - -* - [3.4.4. Asynchronous Context Managers](https://docs.python.org/3/reference/datamodel.html#asynchronous-context-managers) - - Not Supported - - -* - [4. Execution model](https://docs.python.org/3/reference/executionmodel.html#) - - Not Relevant - - -* - [4.1. Structure of a program](https://docs.python.org/3/reference/executionmodel.html#structure-of-a-program) - - Not Relevant - - -* - [4.2. Naming and binding](https://docs.python.org/3/reference/executionmodel.html#naming-and-binding) - - Not Relevant - - Names are bound at compile time in TorchScript -* - [4.2.1. Binding of names](https://docs.python.org/3/reference/executionmodel.html#binding-of-names) - - Not Relevant - - See ``global`` and ``nonlocal`` statements section -* - [4.2.2. Resolution of names](https://docs.python.org/3/reference/executionmodel.html#resolution-of-names) - - Not Relevant - - See ``global`` and ``nonlocal`` statements section -* - [4.2.3. Builtins and restricted execution](https://docs.python.org/3/reference/executionmodel.html#builtins-and-restricted-execution) - - Not Relevant - - -* - [4.2.4. Interaction with dynamic features](https://docs.python.org/3/reference/executionmodel.html#interaction-with-dynamic-features) - - Not Supported - - Python values cannot be captured -* - [4.3. Exceptions](https://docs.python.org/3/reference/executionmodel.html#exceptions) - - Partially Supported - - See ``try`` and ``raise`` statement section -* - [5. The import system](https://docs.python.org/3/reference/import.html) - - Not Relevant - - -* - [6. Expressions](https://docs.python.org/3/reference/expressions.html#) - - Not Relevant - - See expressions section -* - [6.1. Arithmetic conversions](https://docs.python.org/3/reference/expressions.html#arithmetic-conversions) - - Supported - - -* - [6.2. Atoms](https://docs.python.org/3/reference/expressions.html#atoms) - - Not Relevant - - -* - [6.2.1. Identifiers (Names)](https://docs.python.org/3/reference/expressions.html#atom-identifiers) - - Supported - - -* - [6.2.2. Literals](https://docs.python.org/3/reference/expressions.html#literals) - - Partially Supported - - [`bytesliteral``\ , ``imagnumber`` not supported -* - [6.2.3. Parenthesized forms](https://docs.python.org/3/reference/expressions.html#parenthesized-forms) - - Supported - - -* - [6.2.4. Displays for lists, sets and dictionaries](https://docs.python.org/3/reference/expressions.html#displays-for-lists-sets-and-dictionaries) - - Partially Supported - - Not supported: comprehension ifs, async iterators -* - [6.2.5. List displays](https://docs.python.org/3/reference/expressions.html#list-displays) - - Supported - - -* - [6.2.6. Set displays](https://docs.python.org/3/reference/expressions.html#set-displays) - - Not Supported - - -* - [6.2.7. Dictionary displays](https://docs.python.org/3/reference/expressions.html#dictionary-displays) - - Supported - - dict() constructor with kwargs doesn't work, dict comprehensions, dictionary unpacking -* - [6.2.8. Generator expressions](https://docs.python.org/3/reference/expressions.html#generator-expressions) - - Not Supported - - -* - [6.2.9. Yield expressions](https://docs.python.org/3/reference/expressions.html#yield-expressions) - - Not Supported - - -* - [6.2.9.1. Generator-iterator methods](https://docs.python.org/3/reference/expressions.html#generator-iterator-methods) - - Not Supported - - -* - [6.2.9.2. Examples](https://docs.python.org/3/reference/expressions.html#examples) - - Not Supported - - -* - [6.2.9.3. Asynchronous generator functions](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-functions) - - Not Supported - - -* - [6.2.9.4. Asynchronous generator-iterator methods](https://docs.python.org/3/reference/expressions.html#asynchronous-generator-iterator-methods) - - Not Supported - - -* - [6.3. Primaries](https://docs.python.org/3/reference/expressions.html#primaries) - - Supported - - -* - [6.3.1. Attribute references](https://docs.python.org/3/reference/expressions.html#attribute-references) - - Supported - - -* - [6.3.2. Subscriptions](https://docs.python.org/3/reference/expressions.html#subscriptions) - - Supported - - -* - [6.3.3. Slicings](https://docs.python.org/3/reference/expressions.html#slicings) - - Partially Supported - - Tuple slicing with stride is not supported -* - [6.3.4. Calls](https://docs.python.org/3/reference/expressions.html#calls) - - Partially Supported - - Args unpack / kwargs unpack is not supported -* - [6.4. Await expression](https://docs.python.org/3/reference/expressions.html#await-expression) - - Not Supported - - -* - [6.5. The power operator](https://docs.python.org/3/reference/expressions.html#the-power-operator) - - Supported - - -* - [6.6. Unary arithmetic and bitwise operations](https://docs.python.org/3/reference/expressions.html#unary-arithmetic-and-bitwise-operations) - - Partially Supported - - Some bitwise operators are not implemented for primitive types (e.g. ``~x`` where ``x`` is an ``int`` is not currently supported) -* - [6.7. Binary arithmetic operations](https://docs.python.org/3/reference/expressions.html#binary-arithmetic-operations) - - Partially Supported - - See delimiters section -* - [6.8. Shifting operations](https://docs.python.org/3/reference/expressions.html#shifting-operations) - - Not Supported - - -* - [6.9. Binary bitwise operations](https://docs.python.org/3/reference/expressions.html#binary-bitwise-operations) - - Supported - - -* - [6.10. Comparisons](https://docs.python.org/3/reference/expressions.html#comparisons) - - Supported - - -* - [6.10.1. Value comparisons](https://docs.python.org/3/reference/expressions.html#value-comparisons) - - Partially Supported - - Dictionary equality checks are not currently supported -* - [6.10.2. Membership test operations](https://docs.python.org/3/reference/expressions.html#membership-test-operations) - - Partially Supported - - Not supported for TorchScript classes -* - [6.10.3. Identity comparisons](https://docs.python.org/3/reference/expressions.html#is-not) - - Supported - - -* - [6.11. Boolean operations](https://docs.python.org/3/reference/expressions.html#boolean-operations) - - Supported - - -* - [6.12. Conditional expressions](https://docs.python.org/3/reference/expressions.html#conditional-expressions) - - Supported - - -* - [6.13. Lambdas](https://docs.python.org/3/reference/expressions.html#lambda) - - Not Supported - - -* - [6.14. Expression lists](https://docs.python.org/3/reference/expressions.html#expression-lists) - - Partially Supported - - Iterable unpacking not supported -* - [6.15. Evaluation order](https://docs.python.org/3/reference/expressions.html#evaluation-order) - - Supported - - -* - [6.16. Operator precedence](https://docs.python.org/3/reference/expressions.html#operator-precedence) - - Supported - - -* - [7. Simple statements](https://docs.python.org/3/reference/simple_stmts.html#) - - Supported - - -* - [7.1. Expression statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements) - - Supported - - -* - [7.2. Assignment statements](https://docs.python.org/3/reference/simple_stmts.html#assignment-statements) - - Supported - - -* - [7.2.1. Augmented assignment statements](https://docs.python.org/3/reference/simple_stmts.html#augmented-assignment-statements) - - Partially Supported - - See delimiters section -* - [7.2.2. Annotated assignment statements](https://docs.python.org/3/reference/simple_stmts.html#annotated-assignment-statements) - - Supported - - -* - [7.3. The assert statement](https://docs.python.org/3/reference/simple_stmts.html#the-assert-statement) - - Partially Supported - - Exception message is not customizable -* - [7.4. The pass statement](https://docs.python.org/3/reference/simple_stmts.html#the-pass-statement) - - Supported - - -* - [7.5. The del statement](https://docs.python.org/3/reference/simple_stmts.html#the-del-statement) - - Not Supported - - -* - [7.6. The return statement](https://docs.python.org/3/reference/simple_stmts.html#the-return-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.7. The yield statement](https://docs.python.org/3/reference/simple_stmts.html#the-yield-statement) - - Not Supported - - -* - [7.8. The raise statement](https://docs.python.org/3/reference/simple_stmts.html#the-raise-statement) - - Partially Supported - - Exception message is not customizable -* - [7.9. The break statement](https://docs.python.org/3/reference/simple_stmts.html#the-break-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.10. The continue statement](https://docs.python.org/3/reference/simple_stmts.html#the-continue-statement) - - Supported - - Some other features of returning (e.g. behavior with try..finally) are unsupported -* - [7.11. The import statement](https://docs.python.org/3/reference/simple_stmts.html#the-import-statement) - - Not Supported - - -* - [7.11.1. Future statements](https://docs.python.org/3/reference/simple_stmts.html#future-statements) - - Not Supported - - -* - [7.12. The global statement](https://docs.python.org/3/reference/simple_stmts.html#the-global-statement) - - Not Supported - - -* - [7.13. The nonlocal statement](https://docs.python.org/3/reference/simple_stmts.html#the-nonlocal-statement) - - Not Supported - - -* - [8. Compound statements](https://docs.python.org/3/reference/compound_stmts.html#) - - Irrelevant - - -* - [8.1. The if statement](https://docs.python.org/3/reference/compound_stmts.html#the-if-statement) - - Supported - - -* - [8.2. The while statement](https://docs.python.org/3/reference/compound_stmts.html#the-while-statement) - - Partially Supported - - while..else is not supported -* - [8.3. The for statement](https://docs.python.org/3/reference/compound_stmts.html#the-for-statement) - - Partially Supported - - for..else is not supported -* - [8.4. The try statement](https://docs.python.org/3/reference/compound_stmts.html#the-try-statement) - - Not Supported - - -* - [8.5. The with statement](https://docs.python.org/3/reference/compound_stmts.html#the-with-statement) - - Partially Supported - - [`__exit__`` is always called with ``exc_type``, ``exc_value``, and ``traceback`` set to None, even if an exception was raised, and ``__exit__``'s return value is ignored. -* - [8.6. Function definitions](https://docs.python.org/3/reference/compound_stmts.html#function-definitions) - - Not Supported - - -* - [8.7. Class definitions](https://docs.python.org/3/reference/compound_stmts.html#class-definitions) - - Not Supported - - -* - [8.8. Coroutines](https://docs.python.org/3/reference/compound_stmts.html#coroutines) - - Not Supported - - -* - [8.8.1. Coroutine function definition](https://docs.python.org/3/reference/compound_stmts.html#coroutine-function-definition) - - Not Supported - - -* - [8.8.2. The async for statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-for-statement) - - Not Supported - - -* - [8.8.3. The async with statement](https://docs.python.org/3/reference/compound_stmts.html#the-async-with-statement) - - Not Supported - - -* - [9. Top-level components](https://docs.python.org/3/reference/toplevel_components.html#) - - Not Relevant - - -* - [9.1. Complete Python programs](https://docs.python.org/3/reference/toplevel_components.html#complete-python-programs) - - Not Relevant - - -* - [9.2. File input](https://docs.python.org/3/reference/toplevel_components.html#file-input) - - Not Relevant - - -* - [9.3. Interactive input](https://docs.python.org/3/reference/toplevel_components.html#interactive-input) - - Not Relevant - - -* - [9.4. Expression input](https://docs.python.org/3/reference/toplevel_components.html#expression-input) - - Not Relevant - - -``` +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: \ No newline at end of file diff --git a/docs/source/jit_unsupported.md b/docs/source/jit_unsupported.md index 79a51c1651f3..bdb930970f51 100644 --- a/docs/source/jit_unsupported.md +++ b/docs/source/jit_unsupported.md @@ -2,80 +2,7 @@ # TorchScript Unsupported PyTorch Constructs -## Torch and Tensor Unsupported Attributes - -TorchScript supports most methods defined on `torch` and `torch.Tensor`, but we do not have full coverage. -Here are specific known ops and categories of ops which have diverging behavior between -Python and TorchScript. If you encounter something else that is not supported please -file a GitHub issue. Deprecated ops are not listed below. - -```{eval-rst} -.. automodule:: torch.jit.unsupported_tensor_ops -``` - -### Functions Not Correctly Bound on Torch - -The following functions will fail if used in TorchScript, either because they -are not bound on `torch` or because Python expects a different schema than -TorchScript. - -- {func}`torch.tensordot` -- {func}`torch.nn.init.calculate_gain` -- {func}`torch.nn.init.eye_` -- {func}`torch.nn.init.dirac_` -- {func}`torch.nn.init.kaiming_normal_` -- {func}`torch.nn.init.orthogonal_` -- {func}`torch.nn.init.sparse` - -### Ops With Divergent Schemas Between Torch & Python - -The following categories of ops have divergent schemas: - -Functions which construct tensors from non-tensor inputs do not support the `requires_grad` -argument, except for `torch.tensor`. This covers the following ops: - -- {func}`torch.norm` -- {func}`torch.bartlett_window` -- {func}`torch.blackman_window` -- {func}`torch.empty` -- {func}`torch.empty_like` -- {func}`torch.empty_strided` -- {func}`torch.eye` -- {func}`torch.full` -- {func}`torch.full_like` -- {func}`torch.hamming_window` -- {func}`torch.hann_window` -- {func}`torch.linspace` -- {func}`torch.logspace` -- {func}`torch.normal` -- {func}`torch.ones` -- {func}`torch.rand` -- {func}`torch.rand_like` -- {func}`torch.randint_like` -- {func}`torch.randn` -- {func}`torch.randn_like` -- {func}`torch.randperm` -- {func}`torch.tril_indices` -- {func}`torch.triu_indices` -- {func}`torch.vander` -- {func}`torch.zeros` -- {func}`torch.zeros_like` - -The following functions require `dtype`, `layout`, `device` as parameters in TorchScript, -but these parameters are optional in Python. - -- {func}`torch.randint` -- {func}`torch.sparse_coo_tensor` -- {func}`torch.Tensor.to` - -## PyTorch Unsupported Modules and Classes - -TorchScript cannot currently compile a number of other commonly used PyTorch -constructs. Below are listed the modules that TorchScript does not support, and -an incomplete list of PyTorch classes that are not supported. For unsupported modules -we suggest using {meth}`torch.jit.trace`. - -- {class}`torch.nn.RNN` -- {class}`torch.nn.AdaptiveLogSoftmaxWithLoss` -- {class}`torch.autograd.Function` -- {class}`torch.autograd.enable_grad` +:::{warning} +TorchScript is deprecated, please use +[torch.export](https://docs.pytorch.org/docs/stable/export.html) instead. +::: diff --git a/docs/source/notes/cpu_threading_runtimes.svg b/docs/source/notes/cpu_threading_runtimes.svg deleted file mode 100644 index e36ec598f063..000000000000 --- a/docs/source/notes/cpu_threading_runtimes.svg +++ /dev/null @@ -1,208 +0,0 @@ - -image/svg+xml0102030400.51.01.52.02.5# ThreadsTime, s diff --git a/docs/source/notes/cpu_threading_torchscript_inference.rst b/docs/source/notes/cpu_threading_torchscript_inference.rst index e4e55dcf2bd3..8cac34c8c36f 100644 --- a/docs/source/notes/cpu_threading_torchscript_inference.rst +++ b/docs/source/notes/cpu_threading_torchscript_inference.rst @@ -3,160 +3,6 @@ CPU threading and TorchScript inference ================================================= -PyTorch allows using multiple CPU threads during TorchScript model inference. -The following figure shows different levels of parallelism one would find in a -typical application: - -.. image:: cpu_threading_torchscript_inference.svg - :width: 75% - -One or more inference threads execute a model's forward pass on the given inputs. -Each inference thread invokes a JIT interpreter that executes the ops -of a model inline, one by one. A model can utilize a ``fork`` TorchScript -primitive to launch an asynchronous task. Forking several operations at once -results in a task that is executed in parallel. The ``fork`` operator returns a -``Future`` object which can be used to synchronize on later, for example: - -.. code-block:: python - - @torch.jit.script - def compute_z(x): - return torch.mm(x, self.w_z) - - @torch.jit.script - def forward(x): - # launch compute_z asynchronously: - fut = torch.jit._fork(compute_z, x) - # execute the next operation in parallel to compute_z: - y = torch.mm(x, self.w_y) - # wait for the result of compute_z: - z = torch.jit._wait(fut) - return y + z - - -PyTorch uses a single thread pool for the inter-op parallelism, this thread pool -is shared by all inference tasks that are forked within the application process. - -In addition to the inter-op parallelism, PyTorch can also utilize multiple threads -within the ops (`intra-op parallelism`). This can be useful in many cases, -including element-wise ops on large tensors, convolutions, GEMMs, embedding -lookups and others. - - -Build options -------------- - -PyTorch uses an internal ATen library to implement ops. In addition to that, -PyTorch can also be built with support of external libraries, such as MKL_ and MKL-DNN_, -to speed up computations on CPU. - -ATen, MKL and MKL-DNN support intra-op parallelism and depend on the -following parallelization libraries to implement it: - -* OpenMP_ - a standard (and a library, usually shipped with a compiler), widely used in external libraries; -* TBB_ - a newer parallelization library optimized for task-based parallelism and concurrent environments. - -OpenMP historically has been used by a large number of libraries. It is known -for a relative ease of use and support for loop-based parallelism and other primitives. - -TBB is used to a lesser extent in external libraries, but, at the same time, -is optimized for the concurrent environments. PyTorch's TBB backend guarantees that -there's a separate, single, per-process intra-op thread pool used by all of the -ops running in the application. - -Depending of the use case, one might find one or another parallelization -library a better choice in their application. - -PyTorch allows selecting of the parallelization backend used by ATen and other -libraries at the build time with the following build options: - -+------------+------------------------+-----------------------------+----------------------------------------+ -| Library | Build Option | Values | Notes | -+============+========================+=============================+========================================+ -| ATen | ``ATEN_THREADING`` | ``OMP`` (default), ``TBB`` | | -+------------+------------------------+-----------------------------+----------------------------------------+ -| MKL | ``MKL_THREADING`` | (same) | To enable MKL use ``BLAS=MKL`` | -+------------+------------------------+-----------------------------+----------------------------------------+ -| MKL-DNN | ``MKLDNN_CPU_RUNTIME`` | (same) | To enable MKL-DNN use ``USE_MKLDNN=1`` | -+------------+------------------------+-----------------------------+----------------------------------------+ - -It is recommended not to mix OpenMP and TBB within one build. - -Any of the ``TBB`` values above require ``USE_TBB=1`` build setting (default: OFF). -A separate setting ``USE_OPENMP=1`` (default: ON) is required for OpenMP parallelism. - -Runtime API ------------ - -The following API is used to control thread settings: - -+------------------------+-----------------------------------------------------------+---------------------------------------------------------+ -| Type of parallelism | Settings | Notes | -+========================+===========================================================+=========================================================+ -| Inter-op parallelism | ``at::set_num_interop_threads``, | Default number of threads: number of CPU cores. | -| | ``at::get_num_interop_threads`` (C++) | | -| | | | -| | ``set_num_interop_threads``, | | -| | ``get_num_interop_threads`` (Python, :mod:`torch` module) | | -+------------------------+-----------------------------------------------------------+ | -| Intra-op parallelism | ``at::set_num_threads``, | | -| | ``at::get_num_threads`` (C++) | | -| | ``set_num_threads``, | | -| | ``get_num_threads`` (Python, :mod:`torch` module) | | -| | | | -| | Environment variables: | | -| | ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` | | -+------------------------+-----------------------------------------------------------+---------------------------------------------------------+ - -For the intra-op parallelism settings, ``at::set_num_threads``, ``torch.set_num_threads`` always take precedence -over environment variables, ``MKL_NUM_THREADS`` variable takes precedence over ``OMP_NUM_THREADS``. - -Tuning the number of threads ----------------------------- - -The following simple script shows how a runtime of matrix multiplication changes with the number of threads: - -.. code-block:: python - - import timeit - runtimes = [] - threads = [1] + [t for t in range(2, 49, 2)] - for t in threads: - torch.set_num_threads(t) - r = timeit.timeit(setup = "import torch; x = torch.randn(1024, 1024); y = torch.randn(1024, 1024)", stmt="torch.mm(x, y)", number=100) - runtimes.append(r) - # ... plotting (threads, runtimes) ... - -Running the script on a system with 24 physical CPU cores (Xeon E5-2680, MKL and OpenMP based build) results in the following runtimes: - -.. image:: cpu_threading_runtimes.svg - :width: 75% - -The following considerations should be taken into account when tuning the number of intra- and inter-op threads: - -* When choosing the number of threads one needs to avoid `oversubscription` (using too many threads, leads to performance degradation). For example, in an application that uses a large application thread pool or heavily relies on - inter-op parallelism, one might find disabling intra-op parallelism as a possible option (i.e. by calling ``set_num_threads(1)``); - -* In a typical application one might encounter a trade off between `latency` (time spent on processing an inference request) and `throughput` (amount of work done per unit of time). Tuning the number of threads can be a useful - tool to adjust this trade off in one way or another. For example, in latency critical applications one might want to increase the number of intra-op threads to process each request as fast as possible. At the same time, parallel implementations - of ops may add an extra overhead that increases amount work done per single request and thus reduces the overall throughput. - .. warning:: - OpenMP does not guarantee that a single per-process intra-op thread - pool is going to be used in the application. On the contrary, two different application or inter-op - threads may use different OpenMP thread pools for intra-op work. - This might result in a large number of threads used by the application. - Extra care in tuning the number of threads is needed to avoid - oversubscription in multi-threaded applications in OpenMP case. - -.. note:: - Pre-built PyTorch releases are compiled with OpenMP support. - -.. note:: - ``parallel_info`` utility prints information about thread settings and can be used for debugging. - Similar output can be also obtained in Python with ``torch.__config__.parallel_info()`` call. - -.. _OpenMP: https://www.openmp.org/ -.. _TBB: https://github.com/intel/tbb -.. _MKL: https://software.intel.com/en-us/mkl -.. _MKL-DNN: https://github.com/intel/mkl-dnn + TorchScript is deprecated, please use + `torch.export `__ instead. diff --git a/docs/source/notes/cpu_threading_torchscript_inference.svg b/docs/source/notes/cpu_threading_torchscript_inference.svg deleted file mode 100644 index f09884cc5f27..000000000000 --- a/docs/source/notes/cpu_threading_torchscript_inference.svg +++ /dev/null @@ -1,681 +0,0 @@ - -image/svg+xml -Inputs -Application Thread Pool - -Op -Op -Op -Inference thread -Fork -Op -Join - - -Inter -- -op parallelism -Intra -- -op parallelism - -ATen/Parallel -(e.g. at::parallel_for) - -MKL - -MKL -- -DNN - -... -OpenMP -TBB - - diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 98e1d8141dd9..5210eb4ad149 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -64,6 +64,49 @@ Below you can find a small example showcasing this:: TensorFloat-32 (TF32) on Ampere (and later) devices --------------------------------------------------- +After Pytorch 2.9, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way, and +suggest to use the new APIs for better control. +We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. + +.. code:: python + + torch.backends.fp32_precision = "ieee" + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "tf32" + torch.backends.cudnn.rnn.fp32_precision = "tf32" + +The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`. +`ieee` fp32_precision indicate that we will use `FP32` as internal computation precision. +`tf32` fp32_precision indicate that we will allow to use `TF32` as internal computation precision. + +We can override a generic setting for a specific operator if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.cudnn.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +We can also override a generic setting for a specific backend if the fp32_precision is set to `ieee`. + +.. code:: python + + torch.backends.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" + +For above 2 cases, both `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision` +is overridden to `ieee`. + +We suggest to use the new settings for better control. And we do not support to use mix of old and new settings. + +.. warning:: + + Old settings with `allow_tf32` as follows is going to be deprecated. We suggest to use the above new settings for + better control. And we do not support to use mix of old and new settings. + Starting in PyTorch 1.7, there is a new flag called `allow_tf32`. This flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False in PyTorch 1.12 and later. This flag controls whether PyTorch is allowed to use the TensorFloat32 (TF32) tensor cores, @@ -133,44 +176,6 @@ To toggle the TF32 flags off in C++, you can do at::globalContext().setAllowTF32CuBLAS(false); at::globalContext().setAllowTF32CuDNN(false); -After Pytorch 2.7, we provide a new sets of APIs to control the TF32 behavior in a more fine-grained way. -We can set float32 precision per backend and per operators. We can also override the global setting for a specific operator. - -.. code:: python - - torch.backends.fp32_precision = "ieee" - torch.backends.cuda.matmul.fp32_precision = "ieee" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "tf32" - torch.backends.cudnn.rnn.fp32_precision = "tf32" - -The fp32_precision can be set to `ieee` or `tf32` for `cuda/cudnn`. -`ieee` fp32_precision indicate that we will use `FP32` as internal computation precision. -`tf32` fp32_precision indicate that we will allow to use `TF32` as internal computation precision. - -We can override a generic setting for a specific operator if the fp32_precision is set to `ieee`. - -.. code:: python - - torch.backends.cudnn.fp32_precision = "tf32" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" - -We can also override a generic setting for a specific backend if the fp32_precision is set to `ieee`. - -.. code:: python - - torch.backends.fp32_precision = "tf32" - torch.backends.cudnn.fp32_precision = "ieee" - torch.backends.cudnn.conv.fp32_precision = "ieee" - torch.backends.cudnn.rnn.fp32_precision = "ieee" - -For above 2 cases, both `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision` -is overridden to `ieee`. - -Old settings are still supported. But we suggest to use the new settings for better control. And we do not support -to use mix of old and new settings. - For more information about TF32, see: - `TensorFloat-32`_ diff --git a/docs/source/notes/large_scale_deployments.rst b/docs/source/notes/large_scale_deployments.rst index 2829ba0e939b..27380a68cf33 100644 --- a/docs/source/notes/large_scale_deployments.rst +++ b/docs/source/notes/large_scale_deployments.rst @@ -7,9 +7,6 @@ This note talks about several extension points and tricks that might be useful when running PyTorch within a larger system or operating multiple systems using PyTorch in a larger organization. -It doesn't cover topics of deploying models to production. Check -:mod:`torch.jit` or one of the corresponding tutorials. - The note assumes that you either build PyTorch from source in your organization or have an ability to statically link additional code to be loaded when PyTorch is used. Therefore, many of the hooks are exposed as C++ APIs that @@ -86,8 +83,7 @@ scripts, the callback fires only once for a given process for each of the APIs. ``c10::SetAPIUsageHandler`` can be used to register API usage instrumentation handler. Passed argument is going to be an "api key" identifying used point, for -example ``python.import`` for PyTorch extension import or -``torch.script.compile`` if TorchScript compilation was triggered. +example ``python.import`` for PyTorch extension import. .. code-block:: cpp @@ -99,42 +95,6 @@ Note for developers: new API trigger points can be added in code with ``C10_LOG_API_USAGE_ONCE("my_api")`` in C++ or ``torch._C._log_api_usage_once("my.api")`` in Python. -Attaching metadata to saved TorchScript models -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript modules can be saved as an archive file that bundles serialized -parameters and module code as TorchScript (see :meth:`torch.jit.save`). It's -often convenient to bundle additional information together with the model, for -example, description of model producer or auxiliary artifacts. - -It can be achieved by passing the ``_extra_files`` argument to -:meth:`torch.jit.save` and ``torch::jit::load`` to store and retrieve -arbitrary binary blobs during saving process. Since TorchScript files are -regular ZIP archives, extra information gets stored as regular files inside -archive's ``extra/`` directory. - -There's also a global hook allowing to attach extra files to any TorchScript -archive produced in the current process. It might be useful to tag models with -producer metadata, akin to JPEG metadata produced by digital cameras. Example -usage might look like: - -.. code-block:: cpp - - SetExportModuleExtraFilesHook([](const Module&) { - ExtraFilesMap files; - files["producer_info.json"] = "{\"user\": \"" + getenv("USER") + "\"}"; - return files; - }); - - -Build environment considerations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -TorchScript's compilation needs to have access to the original python files as -it uses python's ``inspect.getsource`` call. In certain production environments -it might require explicitly deploying ``.py`` files along with precompiled -``.pyc``. - Common extension points ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/notes/mkldnn.rst b/docs/source/notes/mkldnn.rst index 48ee9ce84c35..8e4a26e50bc5 100644 --- a/docs/source/notes/mkldnn.rst +++ b/docs/source/notes/mkldnn.rst @@ -26,7 +26,7 @@ Users can disable MKLDNN backend by: Bfloat16 (BF16) on MKLDNN backend --------------------------------------------------- -Starting in PyTorch 2.4, there is a set of APIs to control the internal computation precision +Starting in PyTorch 2.9, there is a set of APIs to control the internal computation precision for `float32` operators. .. code:: python @@ -65,6 +65,13 @@ To get an idea of the precision and speed, see the example code and benchmark da relative_error = error / mean # 0.0170 print(error, relative_error) + # Do matmul at TF32 mode. + torch.backends.mkldnn.matmul.fp32_precision = 'tf32' + ab_tf32 = a @ b # expected speedup with TF32 dot-product acceleration + error = (ab_tf32 - ab_full).abs().max() # 0.0004 + relative_error = error / mean # 0.00000552 + print(error, relative_error) + # Do matmul FP32 mode. torch.backends.mkldnn.matmul.fp32_precision = 'ieee' ab_fp32 = a @ b diff --git a/docs/source/notes/numerical_accuracy.rst b/docs/source/notes/numerical_accuracy.rst index 2e081a08442d..8944ecc05f27 100644 --- a/docs/source/notes/numerical_accuracy.rst +++ b/docs/source/notes/numerical_accuracy.rst @@ -93,8 +93,8 @@ On Ampere (and later) Nvidia GPUs, PyTorch can use TensorFloat32 (TF32) to speed When an operation is performed using TF32 tensor cores, only the first 10 bits of the input mantissa are read. This may reduce accuracy and produce surprising results (e.g., multiplying a matrix by the identity matrix may produce results that are different from the input). By default, TF32 tensor cores are disabled for matrix multiplications and enabled for convolutions, although most neural network workloads have the same convergence behavior when using TF32 as they have with fp32. -We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.allow_tf32 = True`` if your network does not need full float32 precision. -If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.allow_tf32 = False``. +We recommend enabling TF32 tensor cores for matrix multiplications with ``torch.backends.cuda.matmul.fp32_precision = "tf32"`` (```torch.backends.cuda.matmul.allow_tf32 = True`` is going to be deprecated) if your network does not need full float32 precision. +If your network needs full float32 precision for both matrix multiplications and convolutions, then TF32 tensor cores can also be disabled for convolutions with ``torch.backends.cudnn.conv.fp32_precision = "ieee"`` (``torch.backends.cudnn.allow_tf32 = False`` is going to be deprecated). For more information see :ref:`TensorFloat32`. diff --git a/docs/source/onnx.md b/docs/source/onnx.md index ad436748022b..184dc8740c79 100644 --- a/docs/source/onnx.md +++ b/docs/source/onnx.md @@ -87,7 +87,6 @@ also be interested in reading our [development wiki](https://github.com/pytorch/ onnx_dynamo onnx_ops onnx_verification - onnx_dynamo_onnxruntime_backend onnx_torchscript ``` diff --git a/docs/source/onnx_dynamo_onnxruntime_backend.md b/docs/source/onnx_dynamo_onnxruntime_backend.md deleted file mode 100644 index a59cd4ab919c..000000000000 --- a/docs/source/onnx_dynamo_onnxruntime_backend.md +++ /dev/null @@ -1,11 +0,0 @@ -# ONNX Backend for TorchDynamo - -For a quick overview of `torch.compiler`, see {ref}`torch.compiler_overview`. - -```{warning} - The ONNX backend for torch.compile is a rapidly evolving beta technology. -``` - -```{eval-rst} -.. autofunction:: torch.onnx.is_onnxrt_backend_supported -``` \ No newline at end of file diff --git a/docs/source/optim.aliases.md b/docs/source/optim.aliases.md new file mode 100644 index 000000000000..09616aefe14a --- /dev/null +++ b/docs/source/optim.aliases.md @@ -0,0 +1,144 @@ +# Aliases in torch.optim + +The following are aliases to their counterparts in ``torch.optim`` in the nested namespaces in which they are defined. For any of these APIs, feel free to use the top-level version in ``torch.optim`` like ``torch.optim.Adam`` or the nested version ``torch.optim.adam.Adam``. + +```{eval-rst} +.. automodule:: torch.optim.adadelta +.. currentmodule:: torch.optim.adadelta +.. autosummary:: + :toctree: generated + :nosignatures: + + Adadelta + adadelta +``` + +```{eval-rst} +.. automodule:: torch.optim.adagrad +.. currentmodule:: torch.optim.adagrad +.. autosummary:: + :toctree: generated + :nosignatures: + + Adagrad + adagrad +``` + +```{eval-rst} +.. automodule:: torch.optim.adam +.. currentmodule:: torch.optim.adam +.. autosummary:: + :toctree: generated + :nosignatures: + + Adam + adam +``` + +```{eval-rst} +.. automodule:: torch.optim.adamax +.. currentmodule:: torch.optim.adamax +.. autosummary:: + :toctree: generated + :nosignatures: + + Adamax + adamax +``` + +```{eval-rst} +.. automodule:: torch.optim.adamw +.. currentmodule:: torch.optim.adamw +.. autosummary:: + :toctree: generated + :nosignatures: + + AdamW + adamw +``` + +```{eval-rst} +.. automodule:: torch.optim.asgd +.. currentmodule:: torch.optim.asgd +.. autosummary:: + :toctree: generated + :nosignatures: + + ASGD + asgd +``` + +```{eval-rst} +.. automodule:: torch.optim.lbfgs +.. currentmodule:: torch.optim.lbfgs +.. autosummary:: + :toctree: generated + :nosignatures: + + LBFGS +``` + +```{eval-rst} +.. automodule:: torch.optim.nadam +.. currentmodule:: torch.optim.nadam +.. autosummary:: + :toctree: generated + :nosignatures: + + NAdam + nadam +``` + +```{eval-rst} +.. automodule:: torch.optim.radam +.. currentmodule:: torch.optim.radam +.. autosummary:: + :toctree: generated + :nosignatures: + + RAdam + radam +``` + +```{eval-rst} +.. automodule:: torch.optim.rmsprop +.. currentmodule:: torch.optim.rmsprop +.. autosummary:: + :toctree: generated + :nosignatures: + + RMSprop + rmsprop +``` + +```{eval-rst} +.. automodule:: torch.optim.rprop +.. currentmodule:: torch.optim.rprop +.. autosummary:: + :toctree: generated + :nosignatures: + + Rprop + rprop +``` + +```{eval-rst} +.. automodule:: torch.optim.sgd +.. currentmodule:: torch.optim.sgd +.. autosummary:: + :toctree: generated + :nosignatures: + + SGD + sgd +``` + +```{eval-rst} +.. automodule:: torch.optim.sparse_adam +.. currentmodule:: torch.optim.sparse_adam +.. autosummary:: + :toctree: generated + :nosignatures: + + SparseAdam +``` diff --git a/docs/source/optim.md b/docs/source/optim.md index 8a3f03468810..38587705ed21 100644 --- a/docs/source/optim.md +++ b/docs/source/optim.md @@ -688,20 +688,14 @@ We train the model for a total of 300 epochs and start to collect EMA averages i ```{eval-rst} -.. py:module:: torch.optim.adadelta -.. py:module:: torch.optim.adagrad -.. py:module:: torch.optim.adam -.. py:module:: torch.optim.adamax -.. py:module:: torch.optim.adamw -.. py:module:: torch.optim.asgd -.. py:module:: torch.optim.lbfgs .. py:module:: torch.optim.lr_scheduler -.. py:module:: torch.optim.nadam .. py:module:: torch.optim.optimizer -.. py:module:: torch.optim.radam -.. py:module:: torch.optim.rmsprop -.. py:module:: torch.optim.rprop -.. py:module:: torch.optim.sgd -.. py:module:: torch.optim.sparse_adam .. py:module:: torch.optim.swa_utils ``` + +```{eval-rst} +.. toctree:: + :hidden: + + optim.aliases.md +``` diff --git a/docs/source/package.md b/docs/source/package.md index e337fedde3e6..1b50f743d579 100644 --- a/docs/source/package.md +++ b/docs/source/package.md @@ -416,21 +416,6 @@ with PackageExporter(f2, importer=(importer, sys_importer)) as exporter: exporter.save_pickle("model", "model.pkl", obj) ``` -### Package a TorchScript module? -To package a TorchScript model, use the same `save_pickle` and `load_pickle` APIs as you would with any other object. -Saving TorchScript objects that are attributes or submodules is supported as well with no extra work. - -```python -# save TorchScript just like any other object -with PackageExporter(file_name) as e: - e.save_pickle("res", "script_model.pkl", scripted_model) - e.save_pickle("res", "mixed_model.pkl", python_model_with_scripted_submodule) -# load as normal -importer = PackageImporter(file_name) -loaded_script = importer.load_pickle("res", "script_model.pkl") -loaded_mixed = importer.load_pickle("res", "mixed_model.pkl" -``` - ## Explanation ### `torch.package` Format Overview diff --git a/docs/source/torch.compiler.md b/docs/source/torch.compiler.md index 5f12670f5e1d..4175da896ccf 100644 --- a/docs/source/torch.compiler.md +++ b/docs/source/torch.compiler.md @@ -56,8 +56,6 @@ Some of the most commonly used backends include: - CUDA graphs with AOT Autograd. `Read more `__ * - ``torch.compile(m, backend="ipex")`` - Uses IPEX on CPU. `Read more `__ - * - ``torch.compile(m, backend="onnxrt")`` - - Uses ONNX Runtime for training on CPU/GPU. :doc:`Read more ` ``` **Inference-only backends** diff --git a/docs/source/torch.compiler_aot_inductor.md b/docs/source/torch.compiler_aot_inductor.md index d2a7c9339264..d8514a920848 100644 --- a/docs/source/torch.compiler_aot_inductor.md +++ b/docs/source/torch.compiler_aot_inductor.md @@ -1,3 +1,5 @@ +(torch.compiler_aot_inductor)= + # AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models ```{warning} @@ -25,7 +27,7 @@ relies on. We will then use {func}`torch._inductor.aoti_compile_and_package` to compile the exported program using TorchInductor, and save the compiled artifacts into one -package. +package. The package is in the format of a {ref}`PT2 Archive Spec `. ```{note} If you have a CUDA-enabled device on your machine and you installed PyTorch with CUDA support, diff --git a/docs/source/torch.compiler_ir.md b/docs/source/torch.compiler_ir.md index ed920a064a68..ff66b8cc7efc 100644 --- a/docs/source/torch.compiler_ir.md +++ b/docs/source/torch.compiler_ir.md @@ -1,3 +1,5 @@ +(torch.compiler_ir)= + # IRs PyTorch 2.0 offers two set of IRs for backends to interface with: Core Aten IR and Prims IR. diff --git a/functorch/README.md b/functorch/README.md index 5021c8591cff..5e16966b1daa 100644 --- a/functorch/README.md +++ b/functorch/README.md @@ -7,7 +7,7 @@ | [**Future Plans**](#future-plans) **This library is currently under heavy development - if you have suggestions -on the API or use-cases you'd like to be covered, please open an github issue +on the API or use-cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.** `functorch` is [JAX-like](https://github.com/google/jax) composable function @@ -161,7 +161,7 @@ result = vmap(model)(examples) ### grad -`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It compute +`grad(func)(*inputs)` assumes `func` returns a single-element Tensor. It computes the gradients of the output of func w.r.t. to `inputs[0]`. ```py @@ -192,7 +192,7 @@ def compute_loss(weights, example, target): weights = torch.randn(feature_size, requires_grad=True) examples = torch.randn(batch_size, feature_size) targets = torch.randn(batch_size) -inputs = (weights,examples, targets) +inputs = (weights, examples, targets) grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) ``` diff --git a/functorch/writing_batching_rules.md b/functorch/writing_batching_rules.md index 8643614acb55..61872c8d5232 100644 --- a/functorch/writing_batching_rules.md +++ b/functorch/writing_batching_rules.md @@ -5,7 +5,7 @@ First off, what are batching rules and why do we need so many of them? Well, to ### How does vmap work? Vmap is a function transform (pioneered by Jax) that allows one to batch functions. That is, given a function `f(x: [N]) -> [N]`, `vmap(f)` now transforms the signature to be `f(x: [B, N]) -> [B, N]`. That is - it adds a batch dimension to both the input and the output of the function. -This guide will gloss over all the cool things you can do this (there are many!), so let's focus on how we actually implement this. +This guide will gloss over all the cool things you can do with this (there are many!), so let's focus on how we actually implement this. One misconception is that this is some magic compiler voodoo, or that it is inherently some function transform. It is not - and there's another framing of it that might make it more clear. diff --git a/pyproject.toml b/pyproject.toml index b41ae87621f0..523fed351b5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "setuptools.build_meta" name = "torch" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" readme = "README.md" -requires-python = ">=3.9,<3.14" +requires-python = ">=3.9" # TODO: change to `license = "BSD-3-Clause"` and enable PEP 639 after pinning setuptools>=77 # FIXME: As of 2025.06.20, it is hard to ensure the minimum version of setuptools in our CI environment. # TOML-table-based license deprecated in setuptools>=77, and the deprecation warning will be changed diff --git a/requirements-build.txt b/requirements-build.txt new file mode 100644 index 000000000000..be19d987f73d --- /dev/null +++ b/requirements-build.txt @@ -0,0 +1,10 @@ +# Build System requirements +setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 +cmake>=3.27 +ninja +numpy +packaging +pyyaml +requests +six # dependency chain: NNPACK -> PeachPy -> six +typing-extensions>=4.10.0 diff --git a/requirements.txt b/requirements.txt index 4526f303c046..2affc4d2215a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,17 @@ # Python dependencies required for development # Build System requirements -setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 -cmake>=3.27 -ninja -numpy -packaging -pyyaml -requests -six # dependency chain: NNPACK -> PeachPy -> six -typing-extensions>=4.10.0 +--requirement requirements-build.txt # Install / Development extra requirements build[uv] # for building sdist and wheel expecttest>=0.3.0 filelock -fsspec +fsspec>=0.8.5 hypothesis jinja2 lintrunner ; platform_machine != "s390x" -networkx +networkx>=2.5.1 optree>=0.13.0 psutil sympy>=1.13.3 diff --git a/scripts/README.md b/scripts/README.md index a1c5ae5f93e6..367e7261f6a6 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -1,40 +1 @@ This directory contains the useful tools. - - -## build_android.sh -This script is to build PyTorch/Caffe2 library for Android. Take the following steps to start the build: - -- set ANDROID_NDK to the location of ndk - -```bash -export ANDROID_NDK=YOUR_NDK_PATH -``` - -- run build_android.sh -```bash -#in your PyTorch root directory -bash scripts/build_android.sh -``` -If succeeded, the libraries and headers would be generated to build_android/install directory. You can then copy these files from build_android/install to your Android project for further usage. - -You can also override the cmake flags via command line, e.g., following command will also compile the executable binary files: -```bash -bash scripts/build_android.sh -DBUILD_BINARY=ON -``` - -## build_ios.sh -This script is to build PyTorch/Caffe2 library for iOS, and can only be performed on macOS. Take the following steps to start the build: - -- Install Xcode from App Store, and configure "Command Line Tools" properly on Xcode. -- Install the dependencies: - -```bash -brew install cmake automake libtool -``` - -- run build_ios.sh -```bash -#in your PyTorch root directory -bash scripts/build_ios.sh -``` -If succeeded, the libraries and headers would be generated to build_ios/install directory. You can then copy these files to your Xcode project for further usage. diff --git a/scripts/add_apache_header.sh b/scripts/add_apache_header.sh deleted file mode 100755 index a29a059d2d03..000000000000 --- a/scripts/add_apache_header.sh +++ /dev/null @@ -1 +0,0 @@ -cat apache_header.txt $1 > _add_apache_header.txt && mv _add_apache_header.txt $1 diff --git a/scripts/apache_header.txt b/scripts/apache_header.txt deleted file mode 100644 index b4eff258eb04..000000000000 --- a/scripts/apache_header.txt +++ /dev/null @@ -1,15 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ diff --git a/scripts/apache_python.txt b/scripts/apache_python.txt deleted file mode 100644 index bc104d884515..000000000000 --- a/scripts/apache_python.txt +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2016-present, Facebook, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -############################################################################## diff --git a/scripts/build_android.sh b/scripts/build_android.sh deleted file mode 100755 index 43f11b86828d..000000000000 --- a/scripts/build_android.sh +++ /dev/null @@ -1,189 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the android target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the Android platform -# using android-cmake. A few notes: -# -# (1) This build also does a host build for protobuf. You will need autoconf -# to carry out this. If autoconf is not possible, you will need to provide -# a pre-built protoc binary that is the same version as the protobuf -# version under third_party. -# If you are building on Mac, you might need to install autotool and -# libtool. The easiest way is via homebrew: -# brew install automake -# brew install libtool -# (2) You will need to have android ndk installed. The current script assumes -# that you set ANDROID_NDK to the location of ndk. -# (3) The toolchain and the build target platform can be specified with the -# cmake arguments below. For more details, check out android-cmake's doc. - -set -e - -# Android specific flags -if [ -z "$ANDROID_ABI" ]; then - ANDROID_ABI="armeabi-v7a with NEON" -fi -ANDROID_NATIVE_API_LEVEL="21" -echo "Build with ANDROID_ABI[$ANDROID_ABI], ANDROID_NATIVE_API_LEVEL[$ANDROID_NATIVE_API_LEVEL]" - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" -if [ -z "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not set; please set it to the Android NDK directory" - exit 1 -fi - -if [ ! -d "$ANDROID_NDK" ]; then - echo "ANDROID_NDK not a directory; did you install it under $ANDROID_NDK?" - exit 1 -fi - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -ANDROID_NDK_PROPERTIES="$ANDROID_NDK/source.properties" -[ -f "$ANDROID_NDK_PROPERTIES" ] && ANDROID_NDK_VERSION=$(sed -n 's/^Pkg.Revision[^=]*= *\([0-9]*\)\..*$/\1/p' "$ANDROID_NDK_PROPERTIES") - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" -echo "Using Android NDK at $ANDROID_NDK" -echo "Android NDK version: $ANDROID_NDK_VERSION" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use android-cmake to build Android project from CMake. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") - -if [ -z "$BUILD_MOBILE_BENCHMARK" ]; then - BUILD_MOBILE_BENCHMARK=0 -fi - -if [ -z "$BUILD_MOBILE_TEST" ]; then - BUILD_MOBILE_TEST=0 -fi -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 0, build full jit interpreter. -# Default behavior is to build lite interpreter -# cmd: BUILD_LITE_INTERPRETER=0 ./scripts/build_android.sh -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DBUILD_MOBILE_BENCHMARK=$BUILD_MOBILE_BENCHMARK") -CMAKE_ARGS+=("-DBUILD_MOBILE_TEST=$BUILD_MOBILE_TEST") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") -if (( "${ANDROID_NDK_VERSION:-0}" < 18 )); then - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=gcc") -else - CMAKE_ARGS+=("-DANDROID_TOOLCHAIN=clang") -fi -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Android specific flags -CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") -CMAKE_ARGS+=("-DANDROID_ABI=$ANDROID_ABI") -CMAKE_ARGS+=("-DANDROID_NATIVE_API_LEVEL=$ANDROID_NATIVE_API_LEVEL") -CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=rtti exceptions") -if [ "${ANDROID_STL_SHARED:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_STL=c++_shared") -fi -if [ "${ANDROID_DEBUG_SYMBOLS:-}" == '1' ]; then - CMAKE_ARGS+=("-DANDROID_DEBUG_SYMBOLS=1") -fi - -if [ -n "${USE_VULKAN}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN=ON") - if [ -n "${USE_VULKAN_FP16_INFERENCE}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_FP16_INFERENCE=ON") - fi - if [ -n "${USE_VULKAN_RELAXED_PRECISION}" ]; then - CMAKE_ARGS+=("-DUSE_VULKAN_RELAXED_PRECISION=ON") - fi -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=($@) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" third_party/pocketfft/pocketfft_hdronly.h -fi - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further Android project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Android project directory." diff --git a/scripts/build_android_gradle.sh b/scripts/build_android_gradle.sh deleted file mode 100755 index fc27c5dd2516..000000000000 --- a/scripts/build_android_gradle.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env bash -set -eux -o pipefail - -env -echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" - -export ANDROID_NDK_HOME=/opt/ndk -export ANDROID_NDK=/opt/ndk -export ANDROID_HOME=/opt/android/sdk - -# Must be in sync with GRADLE_VERSION in docker image for android -# https://github.com/pietern/pytorch-dockerfiles/blob/master/build.sh#L155 -export GRADLE_VERSION=6.8.3 -export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION -export GRADLE_PATH=$GRADLE_HOME/bin/gradle - -# touch gradle cache files to prevent expiration -while IFS= read -r -d '' file -do - touch "$file" || true -done < <(find /var/lib/jenkins/.gradle -type f -print0) - -# Patch pocketfft (as Android does not have aligned_alloc even if compiled with c++17 -if [ -f ~/workspace/third_party/pocketfft/pocketfft_hdronly.h ]; then - sed -i -e "s/__cplusplus >= 201703L/0/" ~/workspace/third_party/pocketfft/pocketfft_hdronly.h -fi - -export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties -rm -f $GRADLE_LOCAL_PROPERTIES -echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES -echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES -echo "cmake.dir=/usr/local" >> $GRADLE_LOCAL_PROPERTIES - -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Run custom build script -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-custom-build* ]]; then - # Install torch & torchvision - used to download & dump used ops from test model. - retry pip install torch torchvision --progress-bar off - - exec "$(dirname "${BASH_SOURCE[0]}")/../android/build_test_app_custom.sh" armeabi-v7a -fi - -# Run default build -BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include -BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib - -BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include -BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include -BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib - -BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include -BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib - -PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main - -JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include -mkdir -p $JNI_INCLUDE_DIR - -JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs -mkdir -p $JNI_LIBS_DIR - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 -ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 - -if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 -ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a - -ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a -ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a -fi - -GRADLE_PARAMS="-p android assembleRelease --debug --stacktrace" -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then - GRADLE_PARAMS+=" -PABI_FILTERS=x86" -fi - -if [ -n "${GRADLE_OFFLINE:-}" ]; then - GRADLE_PARAMS+=" --offline" -fi - -$GRADLE_PATH $GRADLE_PARAMS - -find . -type f -name "*.a" -exec ls -lh {} \; - -while IFS= read -r -d '' file -do - echo - echo "$file" - ls -lah "$file" - zipinfo -l "$file" -done < <(find . -type f -name '*.aar' -print0) - -find . -type f -name *aar -print | xargs tar cfvz ~/workspace/android/artifacts.tgz diff --git a/scripts/build_host_protoc.sh b/scripts/build_host_protoc.sh deleted file mode 100755 index cd37db3b3171..000000000000 --- a/scripts/build_host_protoc.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -############################################################################## -# Build script to build the protoc compiler for the host platform. -############################################################################## -# This script builds the protoc compiler for the host platform, which is needed -# for any cross-compilation as we will need to convert the protobuf source -# files to cc files. -# -# --other-flags accepts flags that should be passed to cmake. Optional. -# -# After the execution of the file, one should be able to find the host protoc -# binary at build_host_protoc/bin/protoc. - -set -e - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_host_protoc"} -mkdir -p $BUILD_ROOT/build -cd $BUILD_ROOT/build - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT") -CMAKE_ARGS+=("-Dprotobuf_BUILD_TESTS=OFF") - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -while true; do - case "$1" in - --other-flags) - shift; - CMAKE_ARGS+=("$@") - break ;; - "") - break ;; - *) - echo "Unknown option passed as argument: $1" - break ;; - esac -done - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ] && [ -d /usr/local/opt/ccache/libexec ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=/usr/local/opt/ccache/libexec/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=/usr/local/opt/ccache/libexec/g++") -fi - -cmake "$CAFFE2_ROOT/third_party/protobuf/cmake" ${CMAKE_ARGS[@]} - -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi -cmake --build . -- "-j${MAX_JOBS}" install diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh deleted file mode 100755 index ad16cb940dcb..000000000000 --- a/scripts/build_ios.sh +++ /dev/null @@ -1,155 +0,0 @@ -#!/bin/bash -xe -############################################################################## -# Example command to build the iOS target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for the iOS platform -# using ios-cmake. This is very similar to the android-cmake - see -# build_android.sh for more details. - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -if [ -z "$PYTHON" ]; then - PYTHON=python - PYTHON_VERSION_MAJOR=$($PYTHON -c 'import sys; print(sys.version_info[0])') - if [ "${PYTHON_VERSION_MAJOR}" -le 2 ]; then - echo "Default python executable is Python-2, trying to use python3 alias" - PYTHON=python3 - fi -fi - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Python: $($PYTHON -c 'import sys; print(sys.version)')" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() - -# Build PyTorch mobile -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# bitcode -if [ "${ENABLE_BITCODE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") -fi - -# Use ios-cmake to build iOS project from CMake. -# This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and -# CMAKE_CXX_COMPILER to /usr/bin/g++. In order to use ccache (if it is available) we -# must override these variables via CMake arguments. -CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$CAFFE2_ROOT/cmake/iOS.cmake") -if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec -fi -if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") -fi - -# IOS_PLATFORM controls type of iOS platform (see ios-cmake) -if [ -n "${IOS_PLATFORM:-}" ]; then - CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") - if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then - # enable bitcode by default for watchos - CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") - CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") - # disable the QNNPACK - CMAKE_ARGS+=("-DUSE_PYTORCH_QNNPACK=OFF") - fi -else - # IOS_PLATFORM is not set, default to OS, which builds iOS. - CMAKE_ARGS+=("-DIOS_PLATFORM=OS") -fi - -if [ -n "${IOS_ARCH:-}" ]; then - CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") -fi - -if [ "${BUILD_LITE_INTERPRETER}" == 0 ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -fi -if [ "${TRACING_BASED}" == 1 ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -CMAKE_ARGS+=("-DUSE_LITE_INTERPRETER_PROFILER=OFF") - -# Don't build binaries or tests (only the library) -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") -CMAKE_ARGS+=("-DBUILD_PYTHON=OFF") - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") - -# Metal -if [ "${USE_PYTORCH_METAL:-}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_PYTORCH_METAL=ON") -fi - -# Core ML -if [ "${USE_COREML_DELEGATE}" == "1" ]; then - CMAKE_ARGS+=("-DUSE_COREML_DELEGATE=ON") -fi - -# pthreads -CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") -CMAKE_ARGS+=("-DCMAKE_HAVE_THREADS_LIBRARY=1") -CMAKE_ARGS+=("-DCMAKE_USE_PTHREADS_INIT=1") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# enable ARC -CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fobjc-arc") - -# Now, actually build the iOS target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=MinSizeRel \ - -DBUILD_SHARED_LIBS=OFF \ - ${CMAKE_ARGS[@]} \ - $@ - -cmake --build . -- "-j$(sysctl -n hw.ncpu)" - -# copy headers and libs to install directory -echo "Will install headers and libs to $INSTALL_PREFIX for further Xcode project usage." -make install -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your Xcode project directory." diff --git a/scripts/build_local.sh b/scripts/build_local.sh deleted file mode 100755 index b84367150125..000000000000 --- a/scripts/build_local.sh +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -# -############################################################################## -# Example command to build Caffe2 -############################################################################## -# - -set -ex - -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -CMAKE_ARGS=() - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Use ccache if available (this path is where Homebrew installs ccache symlinks) -if [ "$(uname)" == 'Darwin' ]; then - if [ -n "${CCACHE_WRAPPER_PATH:-}"]; then - CCACHE_WRAPPER_PATH=/usr/local/opt/ccache/libexec - fi - if [ -d "$CCACHE_WRAPPER_PATH" ]; then - CMAKE_ARGS+=("-DCMAKE_C_COMPILER=$CCACHE_WRAPPER_PATH/gcc") - CMAKE_ARGS+=("-DCMAKE_CXX_COMPILER=$CCACHE_WRAPPER_PATH/g++") - fi -fi - -# Use special install script with Anaconda -if [ -n "${USE_ANACONDA}" ]; then - export SKIP_CONDA_TESTS=1 - export CONDA_INSTALL_LOCALLY=1 - "${ROOT_DIR}/scripts/build_anaconda.sh" "$@" -else - # Make sure that pyyaml is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import yaml' 2>&1)" ]]; then - echo "Installing pyyaml with pip at $(which pip)" - pip install --user pyyaml - fi - - # Make sure that typing is installed for the codegen of building Aten to work - if [[ -n "$(python -c 'import typing' 2>&1)" ]]; then - echo "Installing typing with pip at $(which pip)" - pip install --user typing - fi - - # Build protobuf compiler from third_party if configured to do so - if [ -n "${USE_HOST_PROTOC:-}" ]; then - echo "USE_HOST_PROTOC is set; building protoc before building Caffe2..." - "$CAFFE2_ROOT/scripts/build_host_protoc.sh" - CUSTOM_PROTOC_EXECUTABLE="$CAFFE2_ROOT/build_host_protoc/bin/protoc" - echo "Built protoc $("$CUSTOM_PROTOC_EXECUTABLE" --version)" - CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CUSTOM_PROTOC_EXECUTABLE") - fi - - # We are going to build the target into build. - BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} - mkdir -p "$BUILD_ROOT" - cd "$BUILD_ROOT" - echo "Building Caffe2 in: $BUILD_ROOT" - - cmake "$CAFFE2_ROOT" \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" \ - "$@" - - # Determine the number of CPUs to build with. - # If the `CAFFE_MAKE_NCPUS` variable is not specified, use them all. - if [ -n "${MAX_JOBS}" ]; then - CAFFE_MAKE_NCPUS="$MAX_JOBS" - elif [ -n "${CAFFE_MAKE_NCPUS}" ]; then - CAFFE_MAKE_NCPUS="$CAFFE_MAKE_NCPUS" - elif [ "$(uname)" == 'Darwin' ]; then - CAFFE_MAKE_NCPUS="$(sysctl -n hw.ncpu)" - else - CAFFE_MAKE_NCPUS="$(nproc)" - fi - - # Now, actually build the target. - cmake --build . -- "-j$CAFFE_MAKE_NCPUS" -fi diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh deleted file mode 100755 index 7b1995a61ebc..000000000000 --- a/scripts/build_mobile.sh +++ /dev/null @@ -1,107 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the mobile target. -############################################################################## -# -# This script shows how one can build a libtorch library optimized for mobile -# devices using host toolchain. - -set -e - -export BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN=1 -CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" - -echo "Bash: $(/bin/bash --version | head -1)" -echo "Caffe2 path: $CAFFE2_ROOT" - -CMAKE_ARGS=() -CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'import sysconfig; print(sysconfig.get_path("purelib"))')") -CMAKE_ARGS+=("-DPython_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") -CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") -CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") - -# custom build with selected ops -if [ -n "${SELECTED_OP_LIST}" ]; then - SELECTED_OP_LIST="$(cd $(dirname $SELECTED_OP_LIST); pwd -P)/$(basename $SELECTED_OP_LIST)" - echo "Choose SELECTED_OP_LIST file: $SELECTED_OP_LIST" - if [ ! -r ${SELECTED_OP_LIST} ]; then - echo "Error: SELECTED_OP_LIST file ${SELECTED_OP_LIST} not found." - exit 1 - fi - CMAKE_ARGS+=("-DSELECTED_OP_LIST=${SELECTED_OP_LIST}") -fi - -# If Ninja is installed, prefer it to Make -if [ -x "$(command -v ninja)" ]; then - CMAKE_ARGS+=("-GNinja") -fi - -# Don't build artifacts we don't need -CMAKE_ARGS+=("-DBUILD_TEST=OFF") -CMAKE_ARGS+=("-DBUILD_BINARY=OFF") - -# If there exists env variable and it equals to 1, build lite interpreter. -# Default behavior is to build full jit interpreter. -# cmd: BUILD_LITE_INTERPRETER=1 ./scripts/build_mobile.sh -if [ "x${BUILD_LITE_INTERPRETER}" == "x1" ]; then - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=ON") -else - CMAKE_ARGS+=("-DBUILD_LITE_INTERPRETER=OFF") -fi -if [ "x${TRACING_BASED}" == "x1" ]; then - CMAKE_ARGS+=("-DTRACING_BASED=ON") -else - CMAKE_ARGS+=("-DTRACING_BASED=OFF") -fi - -# Lightweight dispatch bypasses the PyTorch Dispatcher. -if [ "${USE_LIGHTWEIGHT_DISPATCH}" == 1 ]; then - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=ON") - CMAKE_ARGS+=("-DSTATIC_DISPATCH_BACKEND=CPU") -else - CMAKE_ARGS+=("-DUSE_LIGHTWEIGHT_DISPATCH=OFF") -fi - -# Disable unused dependencies -CMAKE_ARGS+=("-DUSE_ROCM=OFF") -CMAKE_ARGS+=("-DUSE_CUDA=OFF") -CMAKE_ARGS+=("-DUSE_ITT=OFF") -CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") -CMAKE_ARGS+=("-DUSE_OPENCV=OFF") -CMAKE_ARGS+=("-DUSE_MPI=OFF") -CMAKE_ARGS+=("-DUSE_OPENMP=OFF") -CMAKE_ARGS+=("-DUSE_MKLDNN=OFF") -CMAKE_ARGS+=("-DUSE_NNPACK=OFF") -CMAKE_ARGS+=("-DUSE_NUMPY=OFF") -CMAKE_ARGS+=("-DUSE_BLAS=OFF") - -# Only toggle if VERBOSE=1 -if [ "${VERBOSE:-}" == '1' ]; then - CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") -fi - -# Use-specified CMake arguments go last to allow overriding defaults -CMAKE_ARGS+=("$@") - -# Now, actually build the Android target. -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} -INSTALL_PREFIX=${BUILD_ROOT}/install -mkdir -p $BUILD_ROOT -cd $BUILD_ROOT -cmake "$CAFFE2_ROOT" \ - -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ - "${CMAKE_ARGS[@]}" - -# Cross-platform parallel build -if [ -z "$MAX_JOBS" ]; then - if [ "$(uname)" == 'Darwin' ]; then - MAX_JOBS=$(sysctl -n hw.ncpu) - else - MAX_JOBS=$(nproc) - fi -fi - -echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." -cmake --build . --target install -- "-j${MAX_JOBS}" -echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/build_pytorch_android.sh b/scripts/build_pytorch_android.sh deleted file mode 100755 index 7b80965e34b5..000000000000 --- a/scripts/build_pytorch_android.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -set -eux - -############################################################################## -# Master script to build PyTorch Android library with Java bindings. -############################################################################## -# Example usage: -# - Build default AARs: -# scripts/build_pytorch_android.sh -# -# - Build for specific ABI(s): -# scripts/build_pytorch_android.sh armeabi-v7a -# scripts/build_pytorch_android.sh arm64-v8a,x86,x86_64 -# -# Script's workflow: -# 1. Builds libtorch for android for specified android abisi (by default for all 4). -# Custom list of android abis can be specified as a bash argument as comma separated list. -# For example just for testing on android x86 emulator we need only x86 build. -# ./scripts/build_pytorch_android.sh x86 -# 2. Creates symbolic links to android/pytorch_android/src/main/jniLibs/${abi} for libtorch build output, -# android/pytorch_android/src/main/cpp/libtorch_include/${abi} for headers. -# 3. Runs pyotrch_android gradle build: -# gradle assembleRelease - -PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)" -PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android - -echo "PYTORCH_DIR:$PYTORCH_DIR" - -source "$PYTORCH_ANDROID_DIR/common.sh" - -check_android_sdk -check_gradle -parse_abis_list "$@" -build_android - -# To set proxy for gradle add following lines to ./gradle/gradle.properties: -# systemProp.http.proxyHost=... -# systemProp.http.proxyPort=8080 -# systemProp.https.proxyHost=... -# systemProp.https.proxyPort=8080 - -if [ "$CUSTOM_ABIS_LIST" = true ]; then - # Skipping clean task here as android gradle plugin 3.3.2 exteralNativeBuild has problems - # with it when abiFilters are specified. - $GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR assembleRelease -else - $GRADLE_PATH -p $PYTORCH_ANDROID_DIR clean assembleRelease -fi - -find $PYTORCH_ANDROID_DIR -type f -name *aar | xargs ls -lah diff --git a/scripts/build_raspbian.sh b/scripts/build_raspbian.sh deleted file mode 100755 index b1fe85926219..000000000000 --- a/scripts/build_raspbian.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build the Raspbian target. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for raspbian. The build -# is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain dependencies. -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - libpython-dev \ - python-pip \ - python-numpy \ - protobuf-compiler \ - python-protobuf -# python dependencies -sudo pip install hypothesis - -# Now, actually build the raspbian target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: you can add more dependencies above if you need libraries such as -# leveldb, lmdb, etc. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=hard" \ - || exit 1 - -# Note: while Raspberry pi has 4 cores, running too many builds in parallel may -# cause out of memory errors so we will simply run -j 2 only. -make -j 2 || exit 1 diff --git a/scripts/build_tegra_x1.sh b/scripts/build_tegra_x1.sh deleted file mode 100755 index 063e17dfe351..000000000000 --- a/scripts/build_tegra_x1.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash -############################################################################## -# Example command to build Caffe2 on Tegra X1. -############################################################################## -# -# This script shows how one can build a Caffe2 binary for NVidia's TX1. -# The build script assumes that you have the most recent libraries installed -# via the JetPack toolkit available at -# https://developer.nvidia.com/embedded/jetpack -# and it assumes that we are starting from a fresh system after the jetpack -# installation. If you have already installed some of the dependencies, you -# may be able to skip quite a few of the apt-get installs. - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 raspbian into: $BUILD_ROOT" - -# obtain necessary dependencies -echo "Installing dependencies." -sudo apt-get install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo apt-get install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by apt-get is quite old so we install it via pip -sudo pip install hypothesis - -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# CUDA_USE_STATIC_CUDA_RUNTIME needs to be set to off so that opencv can be -# properly used. Otherwise, opencv will complain that opencv_dep_cudart cannot -# be found. -cmake "$CAFFE2_ROOT" -DCUDA_USE_STATIC_CUDA_RUNTIME=OFF \ - || exit 1 - -make -j 4 || exit 1 diff --git a/scripts/build_tizen.sh b/scripts/build_tizen.sh deleted file mode 100755 index 2262a2503c1d..000000000000 --- a/scripts/build_tizen.sh +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env bash -############################################################################## -# Example command to build the Tizen target (RPi3). -############################################################################## -# -# This script shows how one can build a Caffe2 binary for a Tizen device (RPi3). -# The build is essentially much similar to a host build, with one additional change -# which is to specify -mfpu=neon for optimized speed. - -setup_environment(){ -# The rootfs image for a Tizen target (RPi3)is located at the below webpage: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/tizen-unified_20170529.1/images/ -# If you do not have a Tizen device, Please, run qemu-arm-static and chroot command. -# $ sudo chroot ~/tizen-rootfs qemu-arm-static /usr/bin/bash - -CAFFE2_ROOT="$( cd "$(dirname -- "$0")"/.. ; pwd -P)" -echo "Caffe2 codebase root is: $CAFFE2_ROOT" -BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build"} -mkdir -p $BUILD_ROOT -echo "Build Caffe2 Tizen into: $BUILD_ROOT" -} - -caffe2_lite_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - make \ - strace \ - cmake \ - gcc* \ - binutils \ - glibc* \ - cpp \ - protobuf-devel \ - libstdc++* -} - -caffe2_lite_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake .. \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_GFLAGS=OFF \ - -DUSE_GLOG=OFF -DUSE_NNPACK=OFF \ - -DRUN_HAVE_STD_REGEX=0 \ - -DRUN_HAVE_POSIX_REGEX=0 \ - -DHAVE_GNU_POSIX_REGEX=0 \ - -DUSE_MPI=OFF -DUSE_OPENMP=OFF \ - -DBUILD_PYTHON=OFF \ - -DUSE_GLOO=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -caffe2_full_dep_packages(){ -# Obtain necessary dependencies -# You can set-up a rpm repository with zypper, yum, and dnf because Tizen -# software platform officially support rpm format such as Fedora, OpenSUSE. -# The official Tizen repository is as following: -# https://cdn.download.tizen.org/archive/releases/milestone/tizen/4.0.m1/ -echo "Installing dependencies." -sudo zypper install \ - cmake \ - libgflags-dev \ - libgoogle-glog-dev \ - libprotobuf-dev \ - protobuf-compiler - -# Obtain optional dependencies that are usually useful to have. -echo "Installing optional dependencies." -sudo zypper install \ - libpython-dev \ - python-numpy \ - python-pip \ - python-protobuf - -# Obtain python hypothesis, which Caffe2 uses for unit testing. Note that -# the one provided by zypper is quite old so we install it via pip -sudo pip install hypothesis -} - -caffe2_full_build(){ -# Now, actually build the android target. -echo "Building caffe2" -cd $BUILD_ROOT - -# Note: add more dependencies above if you need libraries such as leveldb, lmdb, etc. -# If you have to disable a specific package due to a package absence -# from https://git.tizen.org/cgit/, append -Dxxx_xxx=OFF option before executing cmake. -cmake "$CAFFE2_ROOT" \ - -DCMAKE_VERBOSE_MAKEFILE=1 \ - -DUSE_CUDA=OFF \ - -DUSE_ITT=OFF \ - -DUSE_OPENCV=OFF \ - -DCAFFE2_CPU_FLAGS="-mfpu=neon -mfloat-abi=soft" \ - || exit 1 - -make -j`nproc` || exit 1 -} - -#### Main -# Setup a build environment to compile Caffe2 deeplearning framework in Tizen platform. -setup_environment -# There are two build options to support 'full' version and 'lite' version (by default). -caffe2_lite_dep_packages -caffe2_lite_build diff --git a/scripts/build_windows.bat b/scripts/build_windows.bat deleted file mode 100644 index 60bfebad08c0..000000000000 --- a/scripts/build_windows.bat +++ /dev/null @@ -1,80 +0,0 @@ -:: ############################################################################# -:: Example command to build on Windows. -:: ############################################################################# - -:: This script shows how one can build a Caffe2 binary for windows. - -@echo off -setlocal - -SET ORIGINAL_DIR=%cd% -SET CAFFE2_ROOT=%~dp0%.. - -if NOT DEFINED BUILD_BINARY ( - set BUILD_BINARY=OFF -) - -if NOT DEFINED BUILD_SHARED_LIBS ( - :: On CI, we test with BUILD_SHARED_LIBS=OFF. - :: By default, it will be BUILD_SHARED_LIBS=ON. - if NOT DEFINED BUILD_ENVIRONMENT ( - set BUILD_SHARED_LIBS=OFF - ) -) - -if NOT DEFINED CAFFE2_STATIC_LINK_CUDA ( - set CAFFE2_STATIC_LINK_CUDA=OFF -) - -if NOT DEFINED CMAKE_BUILD_TYPE ( - set CMAKE_BUILD_TYPE=Release -) - -if NOT DEFINED ONNX_NAMESPACE ( - set ONNX_NAMESPACE=onnx_c2 -) - -if NOT DEFINED TORCH_CUDA_ARCH_LIST ( - set TORCH_CUDA_ARCH_LIST=5.0 -) - -if NOT DEFINED USE_CUDA ( - set USE_CUDA=OFF -) - -if NOT DEFINED USE_OBSERVERS ( - set USE_OBSERVERS=OFF -) - -if NOT DEFINED MSVC_Z7_OVERRIDE ( - set MSVC_Z7_OVERRIDE=OFF -) - -if NOT DEFINED CMAKE_GENERATOR ( - set CMAKE_GENERATOR=Ninja -) - -set CMAKE_VERBOSE_MAKEFILE=1 - -:: Install pyyaml for Aten codegen -pip install pyyaml ninja - -echo CAFFE2_ROOT=%CAFFE2_ROOT% -echo CMAKE_GENERATOR=%CMAKE_GENERATOR% -echo CMAKE_BUILD_TYPE=%CMAKE_BUILD_TYPE% - -:: Set up cmake. We will skip building the test files right now. -pushd %CAFFE2_ROOT% -python tools\build_libtorch.py || goto :label_error -popd - -echo "Caffe2 built successfully" -cd %ORIGINAL_DIR% -endlocal -exit /b 0 - -:label_error -echo "Caffe2 building failed" -cd %ORIGINAL_DIR% -endlocal -exit /b 1 diff --git a/scripts/diagnose_protobuf.py b/scripts/diagnose_protobuf.py deleted file mode 100644 index 65af4618228d..000000000000 --- a/scripts/diagnose_protobuf.py +++ /dev/null @@ -1,92 +0,0 @@ -## @package diagnose_protobuf -# Module scripts.diagnose_protobuf -"""Diagnoses the current protobuf situation. - -Protocol buffer needs to be properly installed for Caffe2 to work, and -sometimes it is rather tricky. Specifically, we will need to have a -consistent version between C++ and python simultaneously. This is a -convenience script for one to quickly check if this is so on one's local -machine. - -Usage: - [set your environmental variables like PATH and PYTHONPATH] - python scripts/diagnose_protobuf.py -""" - -import os -import re -from subprocess import PIPE, Popen - - -# Get python protobuf version. -try: - import google.protobuf - - python_version = google.protobuf.__version__ - python_protobuf_installed = True -except ImportError: - print("DEBUG: cannot find python protobuf install.") - python_protobuf_installed = False - -if os.name == "nt": - protoc_name = "protoc.exe" -else: - protoc_name = "protoc" - -try: - p = Popen([protoc_name, "--version"], stdout=PIPE, stderr=PIPE) - out, err = p.communicate() -except: - print("DEBUG: did not find protoc binary.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False -else: - if p.returncode: - print("DEBUG: protoc returned a non-zero return code.") - print("DEBUG: out: " + out) - print("DEBUG: err: " + err) - native_protobuf_installed = False - else: - tmp = re.search(r"\d\.\d\.\d", out) - if tmp: - native_version = tmp.group(0) - native_protobuf_installed = True - else: - print("DEBUG: cannot parse protoc version string.") - print("DEBUG: out: " + out) - native_protobuf_installed = False - -PYTHON_PROTOBUF_NOT_INSTALLED = """ -You have not installed python protobuf. Protobuf is needed to run caffe2. You -can install protobuf via pip or conda (if you are using anaconda python). -""" - -NATIVE_PROTOBUF_NOT_INSTALLED = """ -You have not installed the protoc binary. Protoc is needed to compile Caffe2 -protobuf source files. Depending on the platform you are on, you can install -protobuf via: - (1) Mac: using homebrew and do brew install protobuf. - (2) Linux: use apt and do apt-get install libprotobuf-dev - (3) Windows: install from source, or from the releases here: - https://github.com/google/protobuf/releases/ -""" - -VERSION_MISMATCH = f""" -Your python protobuf is of version {python_version} but your native protoc version is of -version {native_version}. This will cause the installation to produce incompatible -protobuf files. This is bad in general - consider installing the same version. -""" - -# Now, give actual recommendations -if not python_protobuf_installed: - print(PYTHON_PROTOBUF_NOT_INSTALLED) - -if not native_protobuf_installed: - print(NATIVE_PROTOBUF_NOT_INSTALLED) - -if python_protobuf_installed and native_protobuf_installed: - if python_version != native_version: - print(VERSION_MISMATCH) - else: - print("All looks good.") diff --git a/scripts/fbcode-dev-setup/ccache_setup.sh b/scripts/fbcode-dev-setup/ccache_setup.sh deleted file mode 100755 index cb461bee2dd2..000000000000 --- a/scripts/fbcode-dev-setup/ccache_setup.sh +++ /dev/null @@ -1,92 +0,0 @@ -#!/bin/bash - -# This script installs CCache with CUDA support. -# Example usage: -# ./ccache_setup.sh --path /installed/folder - -set -e -shopt -s expand_aliases - -# Setup the proxy -alias with_proxy="HTTPS_PROXY=http://fwdproxy:8080 HTTP_PROXY=http://fwdproxy:8080 FTP_PROXY=http://fwdproxy:8080 https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 ftp_proxy=http://fwdproxy:8080 http_no_proxy='*.facebook.com|*.tfbnw.net|*.fb.com'" - -# Parse options -path="$HOME/ccache" -force=false - -while [[ $# -gt 0 ]]; do - case "$1" in - --path) - shift - path="$1" - path=$(realpath "$path") - ;; - --force) # Force install - force=true - ;; - --help) - echo 'usage: ./ccache_setup.py --path /installed/folder [--force]' - exit 0 - ;; - *) - echo "Invalid option: $1" - exit 1 - ;; - esac - shift -done - -# Check whether you put nvcc in PATH -set +e -nvcc_path=$(which nvcc) -if [[ -z "$nvcc_path" ]]; then - nvcc_path="/usr/local/cuda/bin/nvcc" - export PATH="/usr/local/cuda/bin:$PATH" -fi -set -e -if [ ! -f "$nvcc_path" ] && ! $force; then - # shellcheck disable=SC2016 - echo 'nvcc is not detected in $PATH' - exit 1 -fi -echo "nvcc is detected at $nvcc_path" - -if [ -f "$CUDA_NVCC_EXECUTABLE" ] && [[ "$CUDA_NVCC_EXECUTABLE" == *"ccache"* ]]; then # Heuristic rule - if $CUDA_NVCC_EXECUTABLE --version; then - if ! $force; then - echo "CCache with nvcc support is already installed at $CUDA_NVCC_EXECUTABLE, please add --force" - exit 0 - fi - fi -fi - -# Installing CCache -echo "CCache will be installed at $path" -if [ -e "$path" ]; then - mv --backup=t -T "$path" "${path}.old" -fi - -with_proxy git clone https://github.com/colesbury/ccache.git "$path" -b ccbin -cd "$path" -./autogen.sh -./configure -make install prefix="$path" - -mkdir -p "$path/lib" -mkdir -p "$path/cuda" -ln -sf "$path/bin/ccache" "$path/lib/cc" -ln -sf "$path/bin/ccache" "$path/lib/c++" -ln -sf "$path/bin/ccache" "$path/lib/gcc" -ln -sf "$path/bin/ccache" "$path/lib/g++" -ln -sf "$path/bin/ccache" "$path/cuda/nvcc" -"$path/bin/ccache" -M 25Gi - -# Make sure the nvcc wrapped in CCache is runnable -"$path/cuda/nvcc" --version -echo 'Congrats! The CCache with nvcc support is installed!' -echo -e "Please add the following lines to your bash init script:\\n" -echo "################ Env Var for CCache with CUDA support ################" -# shellcheck disable=SC2016 -echo 'export PATH="'"$path"'/lib:$PATH"' -echo 'export CUDA_NVCC_EXECUTABLE="'"$path"'/cuda/nvcc"' -echo '######################################################################' diff --git a/scripts/get_python_cmake_flags.py b/scripts/get_python_cmake_flags.py deleted file mode 100644 index a49debcc884a..000000000000 --- a/scripts/get_python_cmake_flags.py +++ /dev/null @@ -1,24 +0,0 @@ -## @package get_python_cmake_flags -# Module scripts.get_python_cmake_flags -############################################################################## -# Use this script to find your preferred python installation. -############################################################################## -# -# You can use the following to build with your preferred version of python -# if your installation is not being properly detected by CMake. -# -# mkdir -p build && cd build -# cmake $(python ../scripts/get_python_cmake_flags.py) .. -# make -# - - -import sys -import sysconfig - - -flags = [ - f"-DPython_EXECUTABLE:FILEPATH={sys.executable}", -] - -print(" ".join(flags), end="") diff --git a/scripts/release_notes/README.md b/scripts/release_notes/README.md index 6cd34da87b14..c88533f937e7 100644 --- a/scripts/release_notes/README.md +++ b/scripts/release_notes/README.md @@ -130,7 +130,7 @@ This part is a little tedious but it seems to work. May want to explore using pa 5. Install the google doc extension [docs to markdown](https://github.com/evbacher/gd2md-html) 6. Start to compile back down these markdown files into a single markdown file. -`TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room hor improvement here. +`TODO`: This is by far the most manual process and is ripe for automation. If the next person up would like to investigate Google Doc APIS there is some room for improvement here. ### Part 4: Cherry Picks @@ -187,7 +187,7 @@ You will then create a release at [Pytorch Release](https://github.com/pytorch/p #### Tidbits You will probably have a release note that doesn't fit into the character limit of github. I used the following regex: -`\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full lunks to (#). +`\[#(\d+)\]\(https://github.com/pytorch/pytorch/pull/\d+\)` to replace the full links to (#). This will get formatted correctly in the github UI and can be checked when creating a draft release. diff --git a/scripts/remove_apache_header.sh b/scripts/remove_apache_header.sh deleted file mode 100755 index 97980bfbb0ef..000000000000 --- a/scripts/remove_apache_header.sh +++ /dev/null @@ -1,13 +0,0 @@ -if [[ "$1" == *.py ]]; then - apache_header="apache_python.txt" -else - apache_header="apache_header.txt" -fi -apache_lines=$(wc -l < "${apache_header}") -apache_md5=$(cat "${apache_header}" | md5) -header_md5=$(head -n ${apache_lines} $1 | md5) -if [ "${header_md5}" == "${apache_md5}" ]; then - keep_lines=$(($(wc -l < $1) - ${apache_lines})) - tail -n ${keep_lines} $1 > _remove_apache_header.txt - mv _remove_apache_header.txt $1 -fi diff --git a/scripts/run_lintrunner.py b/scripts/run_lintrunner.py new file mode 100644 index 000000000000..60d5b545cf91 --- /dev/null +++ b/scripts/run_lintrunner.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +""" +Pre‑push hook wrapper for Lintrunner. + +✓ Stores a hash of .lintrunner.toml in the venv +✓ Re-runs `lintrunner init` if that file's hash changes +""" + +from __future__ import annotations + +import hashlib +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +LINTRUNNER_TOML_PATH = REPO_ROOT / ".lintrunner.toml" + +# This is the path to the pre-commit-managed venv +VENV_ROOT = Path(sys.executable).parent.parent +# Stores the hash of .lintrunner.toml from the last time we ran `lintrunner init` +INITIALIZED_LINTRUNNER_TOML_HASH_PATH = VENV_ROOT / ".lintrunner_plugins_hash" + + +def ensure_lintrunner() -> None: + """Fail if Lintrunner is not on PATH.""" + if shutil.which("lintrunner"): + print("✅ lintrunner is already installed") + return + sys.exit( + "❌ lintrunner is required but was not found on your PATH. Please run the `python scripts/setup_hooks.py` to install to configure lintrunner before using this script. If `git push` still fails, you may need to open an new terminal" + ) + + +def ensure_virtual_environment() -> None: + """Fail if not running within a virtual environment.""" + in_venv = ( + os.environ.get("VIRTUAL_ENV") is not None + or hasattr(sys, "real_prefix") + or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix) + ) + + if not in_venv: + sys.exit( + "❌ This script must be run from within a virtual environment. " + "Please activate your virtual environment before running this script." + ) + + +def compute_file_hash(path: Path) -> str: + """Returns SHA256 hash of a file's contents.""" + hasher = hashlib.sha256() + with path.open("rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() + + +def read_stored_hash(path: Path) -> str | None: + if not path.exists(): + return None + try: + return path.read_text().strip() + except Exception: + return None + + +def initialize_lintrunner_if_needed() -> None: + """Runs lintrunner init if .lintrunner.toml changed since last run.""" + if not LINTRUNNER_TOML_PATH.exists(): + print("⚠️ No .lintrunner.toml found. Skipping init.") + return + + print( + f"INITIALIZED_LINTRUNNER_TOML_HASH_PATH = {INITIALIZED_LINTRUNNER_TOML_HASH_PATH}" + ) + current_hash = compute_file_hash(LINTRUNNER_TOML_PATH) + stored_hash = read_stored_hash(INITIALIZED_LINTRUNNER_TOML_HASH_PATH) + + if current_hash == stored_hash: + print("✅ Lintrunner plugins already initialized and up to date.") + return + + print("🔁 Running `lintrunner init` …", file=sys.stderr) + subprocess.check_call(["lintrunner", "init"]) + INITIALIZED_LINTRUNNER_TOML_HASH_PATH.write_text(current_hash) + + +def main() -> None: + # 0. Ensure we're running in a virtual environment + ensure_virtual_environment() + print(f"🐍 Virtual env being used: {VENV_ROOT}", file=sys.stderr) + + # 1. Ensure lintrunner binary is available + ensure_lintrunner() + + # 2. Check for plugin updates and re-init if needed + initialize_lintrunner_if_needed() + + # 3. Run lintrunner with any passed arguments and propagate its exit code + args = sys.argv[1:] # Forward all arguments to lintrunner + result = subprocess.call(["lintrunner"] + args) + sys.exit(result) + + +if __name__ == "__main__": + main() diff --git a/scripts/setup_hooks.py b/scripts/setup_hooks.py new file mode 100644 index 000000000000..41f08d45e98b --- /dev/null +++ b/scripts/setup_hooks.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +""" +Bootstrap Git pre‑push hook. + +✓ Requires uv to be installed (fails if not available) +✓ Installs/updates pre‑commit with uv (global, venv‑proof) +✓ Registers the repo's pre‑push hook and freezes hook versions + +Run this from the repo root (inside or outside any project venv): + + python scripts/setup_hooks.py +""" + +from __future__ import annotations + +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Tuple + + +# ─────────────────────────────────────────── +# Helper utilities +# ─────────────────────────────────────────── +def run(cmd: list[str]) -> None: + print(f"$ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def which(cmd: str) -> bool: + return shutil.which(cmd) is not None + + +def ensure_uv() -> None: + if which("uv"): + # Ensure the path uv installs binaries to is part of the system path + print("$ uv tool update-shell") + result = subprocess.run( + ["uv", "tool", "update-shell"], capture_output=True, text=True + ) + if result.returncode == 0: + # Check if the output indicates changes were made + if ( + "Updated" in result.stdout + or "Added" in result.stdout + or "Modified" in result.stdout + ): + print( + "⚠️ Shell configuration updated. You may need to restart your terminal for changes to take effect." + ) + elif result.stdout.strip(): + print(result.stdout) + return + else: + sys.exit( + f"❌ Warning: uv tool update-shell failed: {result.stderr}. uv installed tools may not be available." + ) + + sys.exit( + "\n❌ uv is required but was not found on your PATH.\n" + " Please install uv first using the instructions at:\n" + " https://docs.astral.sh/uv/getting-started/installation/\n" + " Then rerun python scripts/setup_hooks.py\n" + ) + + +def ensure_tool_installed( + tool: str, force_update: bool = False, python_ver: Tuple[int, int] = None +) -> None: + """ + Checks to see if the tool is available and if not (or if force update requested) then + it reinstalls it. + + Returns: Whether or not the tool is available on PATH. If it's not, a new terminal + needs to be opened before git pushes work as expected. + """ + if force_update or not which(tool): + print(f"Ensuring latest {tool} via uv …") + command = ["uv", "tool", "install", "--force", tool] + if python_ver: + # Add the Python version to the command if specified + command.extend(["--python", f"{python_ver[0]}.{python_ver[1]}"]) + run(command) + if not which(tool): + print( + f"\n⚠️ {tool} installation succeed, but it's not on PATH. Launch a new terminal if your git pushes don't work.\n" + ) + + +if sys.platform.startswith("win"): + print( + "\n⚠️ Lintrunner is not supported on Windows, so there are no pre-push hooks to add. Exiting setup.\n" + ) + sys.exit(0) + +# ─────────────────────────────────────────── +# 1. Install dependencies +# ─────────────────────────────────────────── + +ensure_uv() + +# Ensure pre-commit is installed globally via uv +ensure_tool_installed("pre-commit", force_update=True, python_ver=(3, 9)) + +# Don't force a lintrunner update because it might break folks +# who already have it installed in a different way +ensure_tool_installed("lintrunner") + +# ─────────────────────────────────────────── +# 2. Activate (or refresh) the pre‑push hook +# ─────────────────────────────────────────── + +# ── Activate (or refresh) the repo’s pre‑push hook ────────────────────────── +# Creates/overwrites .git/hooks/pre‑push with a tiny shim that will call +# `pre-commit run --hook-stage pre-push` on every `git push`. +# This is why we need to install pre-commit globally. +# +# The --allow-missing-config flag lets pre-commit succeed if someone changes to +# a branch that doesn't have pre-commit installed +run( + [ + "uv", + "tool", + "run", + "pre-commit", + "install", + "--hook-type", + "pre-push", + "--allow-missing-config", + ] +) + +# ── Pin remote‑hook versions for reproducibility ──────────────────────────── +# (Note: we don't have remote hooks right now, but it future-proofs this script) +# 1. `autoupdate` bumps every remote hook’s `rev:` in .pre-commit-config.yaml +# to the latest commit on its default branch. +# 2. `--freeze` immediately rewrites each `rev:` to the exact commit SHA, +# ensuring all contributors and CI run identical hook code. +run(["uv", "tool", "run", "pre-commit", "autoupdate", "--freeze"]) + + +print( + "\n✅ pre‑commit is installed globally via uv and the pre‑push hook is active.\n" + " Lintrunner will now run automatically on every `git push`.\n" +) diff --git a/scripts/temp.sh b/scripts/temp.sh deleted file mode 100755 index 18eb2b473381..000000000000 --- a/scripts/temp.sh +++ /dev/null @@ -1,7 +0,0 @@ -find ../caffe2 -name "*.py" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.h" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cc" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cpp" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.cu" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.mm" -exec ./remove_apache_header.sh {} \; -find ../caffe2 -name "*.m" -exec ./remove_apache_header.sh {} \; diff --git a/scripts/xcode_build.rb b/scripts/xcode_build.rb deleted file mode 100644 index 0734167bdda1..000000000000 --- a/scripts/xcode_build.rb +++ /dev/null @@ -1,76 +0,0 @@ -require 'optparse' -require 'xcodeproj' - -options = {} -option_parser = OptionParser.new do |opts| - opts.banner = 'Tools for building PyTorch iOS framework on MacOS' - opts.on('-i', '--install_path ', 'path to the cmake install folder') { |value| - options[:install] = value - } - opts.on('-x', '--xcodeproj_path ', 'path to the XCode project file') { |value| - options[:xcodeproj] = value - } - opts.on('-p', '--platform ', 'platform for the current build, OS or SIMULATOR') { |value| - options[:platform] = value - } -end.parse! -puts options.inspect - -install_path = File.expand_path(options[:install]) -if not Dir.exist? (install_path) - raise "path don't exist:#{install_path}!" -end -xcodeproj_path = File.expand_path(options[:xcodeproj]) -if not File.exist? (xcodeproj_path) - raise "path don't exist:#{xcodeproj_path}!" -end - -project = Xcodeproj::Project.open(xcodeproj_path) -target = project.targets.first #TestApp -header_search_path = ['$(inherited)', "#{install_path}/include"] -libraries_search_path = ['$(inherited)', "#{install_path}/lib"] -other_linker_flags = ['$(inherited)', "-all_load"] - -target.build_configurations.each do |config| - config.build_settings['HEADER_SEARCH_PATHS'] = header_search_path - config.build_settings['LIBRARY_SEARCH_PATHS'] = libraries_search_path - config.build_settings['OTHER_LDFLAGS'] = other_linker_flags - config.build_settings['ENABLE_BITCODE'] = 'No' -end - -# link static libraries -target.frameworks_build_phases.clear -libs = ['libc10.a', 'libclog.a', 'libpthreadpool.a', 'libXNNPACK.a', 'libmicrokernels-prod.a', 'libeigen_blas.a', 'libcpuinfo.a', 'libpytorch_qnnpack.a', 'libtorch_cpu.a', 'libtorch.a', 'libkineto.a'] -for lib in libs do - path = "#{install_path}/lib/#{lib}" - if File.exist?(path) - libref = project.frameworks_group.new_file(path) - target.frameworks_build_phases.add_file_reference(libref) - end -end -# link system frameworks -frameworks = ['CoreML', 'Metal', 'MetalPerformanceShaders', 'Accelerate', 'UIKit'] -if frameworks - frameworks.each do |framework| - path = "System/Library/Frameworks/#{framework}.framework" - framework_ref = project.frameworks_group.new_reference(path) - framework_ref.name = "#{framework}.framework" - framework_ref.source_tree = 'SDKROOT' - target.frameworks_build_phases.add_file_reference(framework_ref) - end -end -project.save - -sdk = nil -arch = nil -if options[:platform] == 'SIMULATOR' - sdk = 'iphonesimulator' - arch = 'arm64' -elsif options[:platform] == 'OS' - sdk = 'iphoneos' - arch = 'arm64' -else - raise "unsupported platform #{options[:platform]}" -end - -exec "xcodebuild clean build -project #{xcodeproj_path} -alltargets -sdk #{sdk} -configuration Release -arch #{arch}" diff --git a/setup.py b/setup.py index 232cf214d32e..076d3d1e7ec0 100644 --- a/setup.py +++ b/setup.py @@ -1233,9 +1233,9 @@ def main() -> None: "typing-extensions>=4.10.0", 'setuptools ; python_version >= "3.12"', "sympy>=1.13.3", - "networkx", + "networkx>=2.5.1", "jinja2", - "fsspec", + "fsspec>=0.8.5", ] if BUILD_PYTHON_ONLY: install_requires += [f"{LIBTORCH_PKG_NAME}=={TORCH_VERSION}"] diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 9c2a10d3355a..8e1525b85879 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -50,7 +50,7 @@ def _check_constructor(self, activation_sparsifier, model, defaults, sparse_conf sparsifier_defaults = activation_sparsifier.defaults combined_defaults = {**defaults, "sparse_config": sparse_config} - # more keys are populated in activation sparsifier (eventhough they may be None) + # more keys are populated in activation sparsifier (even though they may be None) assert len(combined_defaults) <= len(activation_sparsifier.defaults) for key, config in sparsifier_defaults.items(): diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 999438215743..5217049aafdf 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -265,7 +265,7 @@ def check_memory_reference(self, data_list, data_with_config, defaults, **kwargs class _NormDataSparsifierTestCase(_BaseDataSparsiferTestCase): r"""This helper test class takes in any supported type of and runs some tests. This inherits the TestBaseDataSparsifierRuner wherein some functions are - over-ridden to take accomodate the specific sparsifier. + over-ridden to take accommodate the specific sparsifier. TODO: Change the structure by creating a separate test case class for each member function """ @@ -770,7 +770,7 @@ def test_ptq_quantize_first(self): # higher threshold as quantization occurs before sparsity threshold = ( - 1 # zero points seem to have higher magnitude with sparsity occuring after + 1 # zero points seem to have higher magnitude with sparsity occurring after ) sl_emb1 = (torch.abs(dequant_emb1) < threshold).float().mean() diff --git a/test/ao/sparsity/test_scheduler.py b/test/ao/sparsity/test_scheduler.py index 38e8fca4cdd8..b563efac73bd 100644 --- a/test/ao/sparsity/test_scheduler.py +++ b/test/ao/sparsity/test_scheduler.py @@ -188,7 +188,7 @@ def test_step(self): self.assertEqual( self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels, - msg="Sparsity level is not reaching the target level afer delta_t * n steps ", + msg="Sparsity level is not reaching the target level after delta_t * n steps ", ) diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index 1d8d8e7e3594..f9120c26a132 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -699,14 +699,16 @@ def custom_transforms(fn: str): 8959166 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] ... 92821 /tmp/build/80754af9/python_15996 ... a3/envs/throwaway/bin/python3.6] - 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] + 91000 build/../torch/csrc/tensor/pytho ... ch/torch/lib/libtorch_python.so] # codespell:ignore 91000 /data/users/test_user/repos/pyto ... nsors::get_default_scalar_type() 90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so] 90000 build/../c10/core/TensorImpl.h:c ... ch/torch/lib/libtorch_python.so] 90000 build/../aten/src/ATen/record_fu ... torch/torch/lib/libtorch_cpu.so] 90000 /data/users/test_user/repos/pyto ... uard(std::optional) 90000 /data/users/test_user/repos/pyto ... ersionCounter::~VersionCounter() - 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""", + 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""".replace( + " # codespell:ignore", "" + ), ) self.regularizeAndAssertExpectedInline( diff --git a/test/cpp/aoti_abi_check/test_exception.cpp b/test/cpp/aoti_abi_check/test_exception.cpp new file mode 100644 index 000000000000..74a9fee5d986 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_exception.cpp @@ -0,0 +1,19 @@ +#include + +#include + +namespace torch { +namespace aot_inductor { + +TEST(TestExceptions, TestStdTorchCheck) { + EXPECT_NO_THROW(STD_TORCH_CHECK(true, "dummy true message")); + EXPECT_NO_THROW(STD_TORCH_CHECK(true, "dummy ", "true ", "message")); + EXPECT_THROW( + STD_TORCH_CHECK(false, "dummy false message"), std::runtime_error); + EXPECT_THROW( + STD_TORCH_CHECK(false, "dummy ", "false ", "message"), + std::runtime_error); +} + +} // namespace aot_inductor +} // namespace torch diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 59d575b2cc2b..bff3827f8e8a 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -144,6 +144,8 @@ void test_aoti_package_loader_multi_gpu( const std::string& device, bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; + // Ensure that this test will reset the default CUDA device on exit. + torch::DeviceGuard device_guard(c10::Device("cuda")); std::string data_path = (std::filesystem::path(STRINGIZE(CMAKE_CURRENT_BINARY_DIR)) / "data.pt") diff --git a/test/cpp/api/transformer.cpp b/test/cpp/api/transformer.cpp index 6062c77f5917..fc4832d30157 100644 --- a/test/cpp/api/transformer.cpp +++ b/test/cpp/api/transformer.cpp @@ -73,7 +73,7 @@ void transformer_encoder_layer_test_helper( ASSERT_TRUE( torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true)); - // all 0 values are NOT masked. This should't mask anything + // all 0 values are NOT masked. This shouldn't mask anything torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1; result = model( encoder_input, diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 56f67035a5fb..ac4ba4da0157 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -28,7 +28,7 @@ class NCCLTestBase { NCCLTestBase(NCCLTestBase&& other) noexcept = default; - std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { + ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { return pg_; } @@ -39,7 +39,7 @@ class NCCLTestBase { void initialize( int rank, size_t size, - std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from = + std::optional<::c10::intrusive_ptr<::c10d::ProcessGroupNCCL>> split_from = std::nullopt) { store_ = c10::make_intrusive<::c10d::FileStore>(path_, size); @@ -52,13 +52,13 @@ class NCCLTestBase { opts->split_color = ++color_; } #endif - pg_ = std::make_unique<::c10d::ProcessGroupNCCL>( + pg_ = c10::make_intrusive<::c10d::ProcessGroupNCCL>( store_, rank, size, std::move(opts)); } protected: std::string path_; - std::shared_ptr<::c10d::ProcessGroupNCCL> pg_; + ::c10::intrusive_ptr<::c10d::ProcessGroupNCCL> pg_; std::chrono::milliseconds pgTimeout_; ::c10::intrusive_ptr<::c10d::Store> store_; int color_{1}; @@ -767,8 +767,8 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { } // Test that the CUDAEventCache can be used to create CUDA events and reuse. - auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event1 = c10d::CUDAEventCache::get(1)->create(true); + auto event2 = c10d::CUDAEventCache::get(1)->create(false); auto event1_ptr = event1.get(); auto event2_ptr = event2.get(); @@ -777,14 +777,14 @@ TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { event2 = nullptr; // Test that the CUDAEventCache is indeed reused. - auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(true); - auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get(2)->create(false); + auto event3 = c10d::CUDAEventCache::get(2)->create(true); + auto event4 = c10d::CUDAEventCache::get(2)->create(false); // The cache has been used up, new events should be created. - auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event5 = c10d::CUDAEventCache::get(1)->create(true); + auto event6 = c10d::CUDAEventCache::get(1)->create(false); // The cache has been used up, new events should be created. - auto event7 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(true); - auto event8 = c10d::ProcessGroupNCCL::CUDAEventCache::get(1)->create(false); + auto event7 = c10d::CUDAEventCache::get(1)->create(true); + auto event8 = c10d::CUDAEventCache::get(1)->create(false); EXPECT_NE(event1_ptr, event3.get()); EXPECT_NE(event2_ptr, event4.get()); EXPECT_EQ(event1_ptr, event5.get()); diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index d192d8a6c5d3..f58d81ed008a 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -17,7 +17,7 @@ set(BACKEND_WITH_COMPILER_SRCS ) if(USE_KINETO) # Testing edge profiler for backend use - # profiler_edge should only be aded when USE_KINETO flag is on + # profiler_edge should only be added when USE_KINETO flag is on list(APPEND BACKEND_WITH_COMPILER_SRCS ${TORCH_SRC_DIR}/csrc/jit/mobile/profiler_edge.cpp) endif() diff --git a/test/cpp/jit/test_backend.cpp b/test/cpp/jit/test_backend.cpp index dd4df40d9c13..4a060e436f2b 100644 --- a/test/cpp/jit/test_backend.cpp +++ b/test/cpp/jit/test_backend.cpp @@ -789,7 +789,7 @@ TEST( c._save_for_mobile(ss, ExtraFilesMap(), true); auto c_loaded = _load_for_mobile(ss); /* - * Erro stack trace will look like this: + * Error stack trace will look like this: * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/jit/test_backend_compiler_lib.cpp b/test/cpp/jit/test_backend_compiler_lib.cpp index 33262efd1e2b..55511c3e684a 100644 --- a/test/cpp/jit/test_backend_compiler_lib.cpp +++ b/test/cpp/jit/test_backend_compiler_lib.cpp @@ -79,7 +79,7 @@ class BackendWithCompiler : public PyTorchBackendInterface { // forwards everything along. In a non toy setup this could grab information // from that runtime that might be relevant to execute, such as build flags // the resolution of the devices camera, or basically any runtime specific - // information that wouldnt be available server side where preprocess is + // information that wouldn't be available server side where preprocess is // called. c10::impl::GenericDict compile( c10::IValue processed, diff --git a/test/cpp/jit/test_lite_trainer.cpp b/test/cpp/jit/test_lite_trainer.cpp index a09374065306..950d0c524ad3 100644 --- a/test/cpp/jit/test_lite_trainer.cpp +++ b/test/cpp/jit/test_lite_trainer.cpp @@ -78,7 +78,7 @@ TEST(LiteTrainerTest, Params) { AT_ASSERT(parameters[0].item() == bc_parameters[0].item()); } -// TODO Renable these tests after parameters are correctly loaded on mobile +// TODO Re-enable these tests after parameters are correctly loaded on mobile /* TEST(MobileTest, NamedParameters) { Module m("m"); diff --git a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp index 088a4eb04c99..b6467b7c5b49 100644 --- a/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp +++ b/test/cpp/lite_interpreter_runtime/test_lite_interpreter_runtime.cpp @@ -106,7 +106,7 @@ TEST(RunTimeTest, DelegateException) { * inputs.emplace_back(torch::rand({2, 4})); * inputs.emplace_back(torch::rand({13, 9})); * Run with inputs and expect exception - * Erro stack trace will look like this: + * Error stack trace will look like this: * Module hierarchy:top(C).A0(backend_with_compiler_demoLoweredModule).AA0(AA) * Traceback of TorchScript (most recent call last): * File "", line 3, in FunctionName_UNKNOWN diff --git a/test/cpp/nativert/CMakeLists.txt b/test/cpp/nativert/CMakeLists.txt index 10b750d8b39a..b6e6cd20ced7 100644 --- a/test/cpp/nativert/CMakeLists.txt +++ b/test/cpp/nativert/CMakeLists.txt @@ -24,6 +24,15 @@ set(NATIVERT_TEST_SRCS ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutPlanner.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/LayoutManager.cpp ${TORCH_ROOT}/torch/nativert/executor/memory/AliasAnalyzer.cpp + ${TORCH_ROOT}/torch/nativert/executor/Executor.cpp + ${TORCH_ROOT}/torch/nativert/kernels/KernelFactory.cpp + ${TORCH_ROOT}/torch/nativert/executor/ConstantFolder.cpp + ${TORCH_ROOT}/torch/nativert/executor/GraphExecutorBase.cpp + ${TORCH_ROOT}/torch/nativert/executor/SerialGraphExecutor.cpp + ${TORCH_ROOT}/torch/nativert/executor/ParallelGraphExecutor.cpp + ${TORCH_ROOT}/torch/nativert/kernels/AutoFunctionalizeKernel.cpp + ${TORCH_ROOT}/torch/nativert/kernels/CallTorchBindKernel.cpp + ${TORCH_ROOT}/torch/nativert/kernels/HigherOrderKernel.cpp ) add_executable(test_nativert diff --git a/test/cpp/nativert/test_alias_analyzer.cpp b/test/cpp/nativert/test_alias_analyzer.cpp new file mode 100644 index 000000000000..afa469f58c8b --- /dev/null +++ b/test/cpp/nativert/test_alias_analyzer.cpp @@ -0,0 +1,182 @@ +#include + +#include + +#include +#include + +#include +#include + +using namespace ::testing; +using namespace torch::nativert; + +using AliasTestCase = std::tuple< + std::string /* value */, + AllocationLifetime, + bool /* is_alias */, + bool /* is_storage_associated_with_output */, + c10::FastSet /* source(s) */>; + +class AliasAnalyzerTests : public testing::Test { + void SetUp() override {} + + void TearDown() override { + test_cases.clear(); + model.clear(); + } + + public: + void setTestCases(std::vector cases) { + test_cases = std::move(cases); + } + + void setModel(std::string m) { + model = std::move(m); + } + + void run() { + EXPECT_FALSE(test_cases.empty()); + EXPECT_FALSE(model.empty()); + + ExecutorConfig cfg; + cfg.enableStaticCPUKernels = true; + + auto graph = stringToGraph(model); + auto kernels = KernelFactory().initializeNodeKernels( + *graph, nullptr, cfg, {}, nullptr); + auto kernelSchemas = Executor::getKernelSchemas(kernels.nodeKernels); + + AliasAnalyzer analyzer(*graph, kernelSchemas); + + for ( + auto& [value, lifetime, is_alias, is_storage_associated_with_output, srcs] : + test_cases) { + LOG(INFO) << fmt::format( + "running test: value={}, lifetime=({}, {}), is_alias={}, is_storage_associated_with_output={}, src={}", + value, + lifetime.start, + lifetime.end, + is_alias, + is_storage_associated_with_output, + srcs.empty() ? "{}" + : std::accumulate( + srcs.begin(), + srcs.end(), + std::string{}, + [](std::string cur, const std::string& src) { + cur.append(","); + cur.append(src); + return cur; + })); + auto* v = graph->getValue(value); + EXPECT_EQ(analyzer.lifetime(v), lifetime); + EXPECT_EQ(analyzer.is_alias(v), is_alias); + EXPECT_EQ( + analyzer.is_storage_associated_with_output(v), + is_storage_associated_with_output); + const auto* resolved_srcs = analyzer.get_sources_of_alias(v); + if (resolved_srcs /* ensure set equality between *resolved_srcs and srcs */) { + EXPECT_FALSE(srcs.empty()); + EXPECT_EQ(resolved_srcs->size(), srcs.size()); + for (const auto& resolved_src : *resolved_srcs) { + EXPECT_TRUE(srcs.erase(std::string(resolved_src->name())) == 1); + } + EXPECT_TRUE(srcs.empty()); + } else { + EXPECT_TRUE(srcs.empty()); + } + } + } + + private: + std::string model; + std::vector test_cases; +}; + +TEST_F(AliasAnalyzerTests, TestNoAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %res = torch.ops.aten.clone.default(self=%out_t, memory_format=None) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 2), false, false, {}}, + {"res", AllocationLifetime(2, 3), false, true, {}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestSimpleAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %res = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 3), false, true, {}}, + {"res", AllocationLifetime(2, 3), true, false, {"out_t"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestDeepAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %a1 = torch.ops.aten.slice.Tensor(self=%out_t, dim=1, start=0, end=0, step=1) + %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 4), false, true, {}}, + {"a1", AllocationLifetime(2, 4), true, false, {"out_t"}}, + {"res", AllocationLifetime(3, 4), true, false, {"out_t"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestPackedListUnpack) { + setModel(R"( + graph(%a, %b, %c, %d): + %input_list[] = prim.ListPack(l0=%a, l1=%b, l2=%c, l3=%d) + %x0, %x1, %x2, %x3 = prim.ListUnpack(input=%input_list) + return (%x1, %x3))"); + + setTestCases({ + {"a", AllocationLifetime(0, 2), false, false, {}}, + {"x0", AllocationLifetime(2, 2), true, false, {"a"}}, + {"b", AllocationLifetime(0, 3), false, true, {}}, + {"x1", AllocationLifetime(2, 3), true, false, {"b"}}, + {"c", AllocationLifetime(0, 2), false, false, {}}, + {"x2", AllocationLifetime(2, 2), true, false, {"c"}}, + {"d", AllocationLifetime(0, 3), false, true, {}}, + {"x3", AllocationLifetime(2, 3), true, false, {"d"}}, + }); + + run(); +} + +TEST_F(AliasAnalyzerTests, TestAmbiguousSourceOfAlias) { + setModel(R"( + graph(%y0, %y1): + %out_t = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %out_t2 = torch.ops.aten.matmul.default(self=%y0, other=%y1) + %a1 = prim.VarStack(l0=%out_t, l1=%out_t2) + %res = torch.ops.aten.slice.Tensor(self=%a1, dim=1, start=0, end=0, step=1) + return (%res))"); + + setTestCases({ + {"out_t", AllocationLifetime(1, 5), false, true, {}}, + {"out_t2", AllocationLifetime(2, 5), false, true, {}}, + {"a1", AllocationLifetime(3, 5), true, false, {"out_t", "out_t2"}}, + {"res", AllocationLifetime(4, 5), true, false, {"out_t", "out_t2"}}, + }); + + run(); +} diff --git a/test/cpp/nativert/test_itree.cpp b/test/cpp/nativert/test_itree.cpp index e0004f7db77e..4748c11c3e17 100644 --- a/test/cpp/nativert/test_itree.cpp +++ b/test/cpp/nativert/test_itree.cpp @@ -259,7 +259,7 @@ TEST(ITreeTest, NoContext) { c10::IValue(8), c10::IValue(9), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, TooManyContext) { @@ -304,7 +304,7 @@ TEST(ITreeTest, TooManyContext) { c10::IValue(8), c10::IValue(9), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, DoubleRegister) { @@ -375,7 +375,7 @@ TEST(ITreeTest, NotEnoughUnflatten) { c10::IValue(2), c10::IValue(7), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, TooManyUnflatten) { @@ -449,7 +449,7 @@ TEST(ITreeTest, TooManyUnflatten) { c10::IValue(2), c10::IValue(7), }; - ASSERT_DEATH({ itreeUnflatten(flats, spec); }, "Check failed"); + EXPECT_THROW({ itreeUnflatten(flats, spec); }, c10::Error); } TEST(ITreeTest, Flatten) { @@ -908,8 +908,8 @@ TEST(ITreeTest, UnmatchedDictFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); - ASSERT_DEATH( - { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); + EXPECT_THROW( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); } TEST(ITreeTest, DictFlattenTest) { @@ -1025,8 +1025,8 @@ TEST(ITreeTest, UnmatchedTupleFlatten) { list.push_back(std::move(tup)); list.push_back(c10::IValue(2)); list.push_back(std::move(dict)); - ASSERT_DEATH( - { itreeFlatten(c10::IValue{std::move(list)}, spec); }, "Check failed"); + EXPECT_THROW( + { itreeFlatten(c10::IValue{std::move(list)}, spec); }, c10::Error); } TEST(ITreeTest, ToAtenType) { diff --git a/test/cpp/nativert/test_placement.cpp b/test/cpp/nativert/test_placement.cpp index e88ae20e1de0..ab65bfc07b91 100644 --- a/test/cpp/nativert/test_placement.cpp +++ b/test/cpp/nativert/test_placement.cpp @@ -8,23 +8,6 @@ using namespace ::testing; namespace torch::nativert { -TEST(PlacementTest, NormalizeDevice) { - c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); - c10::Device cpuDevice1 = c10::Device(c10::DeviceType::CPU); - cpuDevice1.set_index(1); - - EXPECT_EQ(normalizeDevice(cpuDevice), cpuDevice); - EXPECT_NE(normalizeDevice(cpuDevice1), cpuDevice1); - - c10::Device cudaDevice = c10::Device(c10::DeviceType::CUDA); - c10::Device cudaDevice1 = c10::Device(c10::DeviceType::CUDA, 1); - EXPECT_EQ(normalizeDevice(cudaDevice), c10::Device(c10::DeviceType::CUDA, 0)); - EXPECT_EQ( - normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 1)); - - EXPECT_NE( - normalizeDevice(cudaDevice1), c10::Device(c10::DeviceType::CUDA, 0)); -} TEST(PlacementTest, IsSameDevice) { c10::Device cpuDevice = c10::Device(c10::DeviceType::CPU); @@ -90,11 +73,11 @@ TEST(PlacementTest, Placement) { {c10::Device("cuda:0"), c10::Device("cuda:1")}}; Placement p1(deviceMap1); EXPECT_EQ(p1.getMappedDevice(c10::Device("cpu")), c10::Device("cpu")); - EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda:1")); + EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda")), c10::Device("cuda")); EXPECT_EQ(p1.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:1")); std::unordered_map deviceMap2 = { - {c10::Device("cpu"), c10::Device("cuda")}}; + {c10::Device("cpu"), c10::Device("cuda:0")}}; Placement p2(deviceMap2); EXPECT_EQ(p2.getMappedDevice(c10::Device("cpu")), c10::Device("cuda:0")); EXPECT_EQ(p2.getMappedDevice(c10::Device("cuda:0")), c10::Device("cuda:0")); diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp index 8a96c68dc75e..2e1e84e758db 100644 --- a/test/cpp/tensorexpr/test_cuda.cpp +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -1681,7 +1681,7 @@ TEST(Cuda, MaskMultiDim_CUDA) { // Tests the case where loop extents are symbolic and not known at compile time. // In this case both stores must be masked against the extent of the other loop, -// incase it is larger. +// in case it is larger. TEST(Cuda, MaskMultiDimSymbolic_CUDA) { VarHandle OUTER_SIZE("OUTER_SIZE", kLong); VarHandle A_SIZE("A_SIZE", kLong); diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 22f6b64efe1a..dc67928b111a 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -38,12 +38,12 @@ TEST_F(Kernel, ParallelExternalCallBuf) { %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2) return (%4))IR"; auto graph = std::make_shared(); - torch::jit::parseIR(graph_string, &*graph); + torch::jit::parseIR(graph_string, graph.get()); +#ifdef TORCH_ENABLE_LLVM const std::string& verification_pattern = R"IR( # CHECK: for (int64_t i = 0ll; i < 5000ll; i++) /* parallel */{)IR"; -#ifdef TORCH_ENABLE_LLVM TensorExprKernel k(graph); StmtPtr s = k.getCodeGenStmt(); std::ostringstream oss; @@ -1113,7 +1113,7 @@ TEST_F(Kernel, Softmax2D) { const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, // oss.str()); @@ -1192,7 +1192,7 @@ TEST_F(Kernel, Softmax3D) { ver_env.d("softmax_dim_size", softmax_dim_size); const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1275,7 +1275,7 @@ TEST_F(Kernel, Softmax4D) { ver_env.d("softmax_dim_size", softmax_dim_size); const auto verification_pattern = format(verification_template, ver_env); - // verification sting temporarily disabled until + // verification string temporarily disabled until // inlining of exp() is benchmarked and determined // torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); @@ -1887,7 +1887,7 @@ graph(%x : int, auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong)); // Verify that TEK::runFast works correctly with mixed scalar and tensor - // inputs/utputs + // inputs/outputs std::vector inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()}; std::vector outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()}; k.runFast(inputs, outputs); @@ -1897,7 +1897,7 @@ graph(%x : int, ASSERT_TRUE(at::equal(rt, zt * xt)); // Verify that TEK::run works correctly with mixed scalar and tensor - // inputs/utputs + // inputs/outputs std::vector stack = {x, xt, y, yt}; k.run(stack); TORCH_CHECK_EQ(stack[0], x * y * x); diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp index cac7283f2beb..5db84eab1f50 100644 --- a/test/cpp/tensorexpr/test_memdependency.cpp +++ b/test/cpp/tensorexpr/test_memdependency.cpp @@ -437,12 +437,12 @@ TEST(MemDependency, BoundSubtractMultiDim) { ASSERT_TRUE(EQ( subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 10), CB(0, 20)}), {})); - // Mutli dim one way partial in dim 1. + // Multi dim one way partial in dim 1. ASSERT_TRUE( EQ(subtractIndicesBounds({CB(0, 9), CB(0, 2)}, {CB(0, 3), CB(0, 2)}), {{CB(4, 9), CB(0, 2)}})); - // Mutli dim one way partial in dim 2. + // Multi dim one way partial in dim 2. ASSERT_TRUE( EQ(subtractIndicesBounds({CB(0, 9), CB(0, 20)}, {CB(0, 9), CB(0, 10)}), {{CB(0, 9), CB(11, 20)}})); @@ -939,7 +939,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBounds) { */ // Now let's look at the bounds of each access. - // There are 9 accesses in this Stmt, so this is exhaustive, we wont do this + // There are 9 accesses in this Stmt, so this is exhaustive, we won't do this // much. auto history = analyzer.getHistory(); ASSERT_EQ(history.size(), 10); @@ -1134,7 +1134,7 @@ TEST(MemDependency, MemDependencyCheckerLoopBoundsIndexShift) { // this case -1. ASSERT_TRUE(EQ(history[1]->bounds(), {CB(0, 8)})); // It depends on the input, but also the store in the same loop, since - // different interations of the loop depend on each other. + // different iterations of the loop depend on each other. ASSERT_EQ(history[1]->dependencies().size(), 2); ASSERT_TRUE(history[1]->hasDependency(history[0])); ASSERT_TRUE(history[1]->hasDependency(history[2])); diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index bdc744ae4e03..fb83ab85b71e 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -333,7 +333,7 @@ TEST(Reductions, ReduceMinCustomInitializer) { cg.call({in, out, std::numeric_limits::max()}); ASSERT_EQ(out[0], 10); - // With an initalizer lower than the min, that's the min. + // With an initializer lower than the min, that's the min. cg.call({in, out, 5.f}); ASSERT_EQ(out[0], 5); } diff --git a/test/cpp/tensorexpr/test_registerizer.cpp b/test/cpp/tensorexpr/test_registerizer.cpp index d6f5977789a9..6cbd04264c32 100644 --- a/test/cpp/tensorexpr/test_registerizer.cpp +++ b/test/cpp/tensorexpr/test_registerizer.cpp @@ -1254,7 +1254,7 @@ TEST(Registerizer, RegisterizerConditionInsideOverlap2) { * A[0] = 3; * A[x] = (A[x]) + 1; * } - * int A_2 = A[x]; // A_2 initialier + * int A_2 = A[x]; // A_2 initializer * B[x] = A_2; // * B[x + 1] = A_2; // * A_2 = C[x]; // @@ -3064,7 +3064,7 @@ TEST(Registerizer, RegisterizerHiddenAccessNo) { } // In this case the conditional access must be hoisted by two loops, there are -// two accesses here one is unhidden and the other isnt. A[0] can be +// two accesses here one is unhidden and the other isn't. A[0] can be // registerized but B[0] cannot. TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) { BufHandle a("A", {10}, kInt); @@ -3422,8 +3422,8 @@ TEST(Registerizer, RegisterizerMultiDim) { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } -// Wont registerize if only some dims match, but will still registerize distinct -// elements. +// Won't registerize if only some dims match, but will still registerize +// distinct elements. TEST(Registerizer, RegisterizerMultiDimPartial) { BufHandle a("A", {3, 4, 5}, kInt); VarHandle x("x", kInt); diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index 99a00d0d62c1..7ca2b74eaa76 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -2643,7 +2643,7 @@ TEST(Simplify, SimplifyWontReorderFloat) { VarHandle x("x", kFloat); VarHandle y("y", kFloat); // x%y - (x%y - 1) => x%y - (x%y - 1). - // We wont reorder opaque ops if they are FP. + // We won't reorder opaque ops if they are FP. ExprHandle body = (x % y) - ((x % y) - 1); ExprHandle simplified = IRSimplifier::simplify(body); @@ -2794,7 +2794,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity checking we wont do the optimization on floats. + // Sanity checking we won't do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); ExprHandle body = ((x / y) * y) + (x % y); @@ -2811,7 +2811,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the mod term doesn't match. + // Sanity check we won't do it if the mod term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2821,7 +2821,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the div term doesn't match. + // Sanity check we won't do it if the div term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -2831,7 +2831,7 @@ TEST(Simplify, SimplifyRoundModPattern) { } { - // Sanity check we wont do it if the mul term doesn't match. + // Sanity check we won't do it if the mul term doesn't match. VarHandle x("x", kInt); VarHandle y("y", kInt); VarHandle z("z", kInt); @@ -3013,7 +3013,7 @@ TEST(Simplify, SimplifyModRoundModPattern) { } { - // Sanity checking we wont do the optimization on floats. + // Sanity checking we won't do the optimization on floats. VarHandle x("x", kFloat); VarHandle y("y", kFloat); VarHandle z("z", kFloat); @@ -4264,7 +4264,7 @@ TEST(Simplify, SimplifyReorderForCond) { { // Condition uses distinct region of Tensor. - // We could reorder here wih better analysis, but we don't. Included for + // We could reorder here with better analysis, but we don't. Included for // completeness. auto body = For::make( i, @@ -4643,7 +4643,7 @@ TEST(Simplify, SimplifyFuseConditions) { } { - // Sanity check wont fuse different non-CompareSelects. + // Sanity check won't fuse different non-CompareSelects. auto body = Block::make( {Cond::make(i, Store::make(a, {0}, i), nullptr), Cond::make(j, Store::make(a, {1}, i), nullptr)}); diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 554203752479..a46974e511d5 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,6 +1,8 @@ #include #include #include +#include +#include #include @@ -32,6 +34,8 @@ Tensor sgd_out_of_place( const float weight_decay, const double lr, const bool maximize) { + STD_TORCH_CHECK(param.dim() == 1, "param must be 1D"); + int64_t *param_sizes; int64_t *param_strides; aoti_torch_get_sizes(param.get(), ¶m_sizes); @@ -254,3 +258,50 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("is_contiguous", &boxed_is_contiguous); } + +Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { + return transpose(t, dim0, dim1); +} + +void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); + + stack[0] = from(res); +} + +Tensor my_empty_like(Tensor t) { + return empty_like(t); +} + +void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_empty_like(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor"); + m.def("my_empty_like(Tensor t) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("my_transpose", &boxed_my_transpose); + m.impl("my_empty_like", &boxed_empty_like); +} + + +Tensor my_zero_(Tensor t) { + return zero_(t); +} + +void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = my_zero_(to(stack[0])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("my_zero_(Tensor(a!) t) -> Tensor(a!)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { + m.impl("my_zero_", &boxed_my_zero_); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 2b4fbd40eb1a..371d8b455e18 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -116,3 +116,39 @@ def is_contiguous(t) -> bool: Returns: is_contiguous(t) """ return torch.ops.libtorch_agnostic.is_contiguous.default(t) + + +def my_transpose(t, dim0, dim1) -> Tensor: + """ + Returns t.transpose(dim0, dim1) + + Args: + t: Tensor + + Returns: my_transpose(t, dim0, dim1) + """ + return torch.ops.libtorch_agnostic.my_transpose.default(t, dim0, dim1) + + +def my_empty_like(t) -> Tensor: + """ + Returns t.empty_like() + + Args: + t: Tensor + + Returns: my_empty_like(t) + """ + return torch.ops.libtorch_agnostic.my_empty_like.default(t) + + +def my_zero_(t) -> Tensor: + """ + Returns t.zero_() + + Args: + t: Tensor + + Returns: my_zero_(t) + """ + return torch.ops.libtorch_agnostic.my_zero_.default(t) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index ba1d6411b098..e1b62a8d3c3c 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -173,6 +173,40 @@ def _make_cuda_tensors(prior_mem): curr_mem = torch.cuda.memory_allocated(device) self.assertEqual(curr_mem, init_mem) + def test_my_transpose(self, device): + import libtorch_agnostic + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_transpose(t, 0, 1) + self.assertEqual(out, torch.transpose(t, 0, 1)) + + with self.assertRaisesRegex(RuntimeError, "API call failed"): + libtorch_agnostic.ops.my_transpose(t, 1, 2) + + def test_my_empty_like(self, device): + import libtorch_agnostic + + deterministic = torch.are_deterministic_algorithms_enabled() + try: + # set use_deterministic_algorithms to fill unintialized memory + torch.use_deterministic_algorithms(True) + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_empty_like(t) + self.assertTrue(id(out != id(t))) + self.assertEqual(out, torch.empty_like(t)) + finally: + torch.use_deterministic_algorithms(deterministic) + + @onlyCPU + def test_my_zero_(self, device): + import libtorch_agnostic + + t = torch.rand(2, 7, device=device) + out = libtorch_agnostic.ops.my_zero_(t) + self.assertEqual(id(out), id(t)) + self.assertEqual(out, torch.zeros_like(t)) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md deleted file mode 100644 index cf32c3afbb06..000000000000 --- a/test/cpp_extensions/open_registration_extension/README.md +++ /dev/null @@ -1,37 +0,0 @@ -# PyTorch OpenReg - -This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core. - -## How to use - -Install as standalone with `python -m pip install -e .` (or `python -m pip install .`) -from this folder. You can run test via `python {PYTORCH_ROOT_PATH}/test/test_openreg.py`. - -## Design principles - -For simplicity anything that can be implemented from python is done so. -A real implementation will most likely want to call these different APIs from c++ directly. - -The current version sends everything back to python and contains enough implementation to run basic model, transfer host/device and printing. - -The codebase is split as follows: - -- `pytorch_openreg/__init__.py` - - imports torch to get core state initialized. - - imports `._aten_impl` to register our aten op implementations to torch. - - imports `.C` to load our c++ extension that registers more ops, allocator and hooks. - - renames the PrivateUse1 backend and register our python-side module. -- `pytorch_openreg/_aten_impl.py` - - Define a new `torch.Library` that registers a fallback that will be called whenever a backend kernel for PrivateUse1 is called. It contains the logic to handle all kind of native functions, computing the output metadata, allocating it and only calling into the device daemon to perform computation. -- `pytorch_openreg/_device_daemon.py` - - contains the Allocator (responsible for allocating memory on the device side and host side, as int8 buffers). - - contains `Driver`, which as user-process driver to deal with some information needed to be done in driver. - - contains `Executor`, which as device-process exector to do something related device logic. -- `pytorch_openreg/_meta_parser.py` mainly contain utilities to send objects over the wire from the user process to the device process. - - The main class there is `OpenRegTensorMeta` that contains all the metadata sent to the device which should be enough for it to populate the output Tensor. - -## Next steps - -The main next step would be to: - -- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py deleted file mode 100644 index 05b8955b6557..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py +++ /dev/null @@ -1,122 +0,0 @@ -import types - -import torch - -# Create our python implementation dict so that the C++ module -# can access it during its initialization and also register aten impls. -from ._aten_impl import impl_factory as impl_factory # noqa: F401 -from ._device_daemon import driver - - -# Load the C++ Module -import pytorch_openreg._C # isort:skip # type: ignore[import] # noqa: F401 - - -def _create_module(): - module = types.ModuleType("_OpenRegMod") - - class device: - r"""Context-manager that changes the selected device. - - Args: - device (torch.device or int): device index to select. It's a no-op if - this argument is a negative integer or ``None``. - """ - - def __init__(self, device): - self.idx = torch.accelerator._get_device_index(device, optional=True) - self.prev_idx = -1 - - def __enter__(self): - self.prev_idx = driver.exec("exchangeDevice", self.idx) - - def __exit__(self, type, value, traceback): - self.idx = driver.exec("uncheckedSetDevice", self.prev_idx) - return False - - def device_count() -> int: - return driver.exec("deviceCount") - - def is_available(): - return True - - def current_device(): - return torch.accelerator.current_device_index() - - def get_rng_state(device="openreg"): - if isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("openreg", device) - idx = device.index - if idx is None: - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - return default_generator.get_state() - - def set_rng_state(new_state, device="openreg"): - if isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("openreg", device) - idx = device.index - if idx is None: - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.set_state(new_state) - - def initial_seed() -> int: - _lazy_init() - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - return default_generator.initial_seed() - - def manual_seed(seed: int) -> None: - seed = int(seed) - - idx = current_device() - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.manual_seed(seed) - - def manual_seed_all(seed: int) -> None: - seed = int(seed) - - for idx in range(device_count()): - default_generator = pytorch_openreg._C._get_default_generator(idx) - default_generator.manual_seed(seed) - - def is_initialized(): - return module._initialized - - def _is_in_bad_fork(): - return False - - def _lazy_init(): - if is_initialized(): - return - pytorch_openreg._C._init() - module._initialized = True - - module.is_available = is_available # type: ignore[assignment] - - module._initialized = False # type: ignore[assignment] - module._lazy_init = _lazy_init # type: ignore[assignment] - module.is_initialized = is_initialized # type: ignore[assignment] - - module.device = device # type: ignore[assignment] - module.device_count = device_count # type: ignore[assignment] - module.current_device = current_device # type: ignore[assignment] - module.get_rng_state = get_rng_state # type: ignore[assignment] - module.set_rng_state = set_rng_state # type: ignore[assignment] - module._is_in_bad_fork = _is_in_bad_fork # type: ignore[assignment] - module.initial_seed = initial_seed # type: ignore[assignment] - module.manual_seed = manual_seed # type: ignore[assignment] - module.manual_seed_all = manual_seed_all # type: ignore[assignment] - - return module - - -# Set all the appropriate state on PyTorch -torch.utils.rename_privateuse1_backend("openreg") -torch._register_device_module("openreg", _create_module()) -torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py deleted file mode 100644 index d4c49bd28d45..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ /dev/null @@ -1,186 +0,0 @@ -import logging - -import torch -from torch.utils._pytree import tree_any - - -log = logging.getLogger(__name__) - -from ._device_daemon import driver -from ._meta_parser import prepare_for_sending, to_device_no_copy - - -_IMPL_REGISTRY = {} - - -def impl_factory(name): - if name in _IMPL_REGISTRY: - return _IMPL_REGISTRY[name] - - def _(*args, **kwargs): - log.info("Calling hook %s", name) - return driver.exec(name, *args, **kwargs) - - _IMPL_REGISTRY[name] = _ - return _ - - -def _openreg_kernel_fallback(op, *args, **kwargs): - def get_tensor_device(*args): - for arg in args: - if isinstance(arg, torch.Tensor) and arg.device.type == "openreg": - return arg.device - - device = get_tensor_device(*args) - if device is None: - return _kernel_fallback(op, *args, **kwargs) - - # Mimicks the DeviceGuard system we have in aten - with torch.openreg.device(device): # type: ignore[misc] - return _kernel_fallback(op, *args, **kwargs) - - -def _kernel_fallback(op, *args, **kwargs): - log.info("Calling kernel %s", op) - - op_name = None - post_process = None - if "out" in op._overloadname: - # Note that all structured native op will call here - if isinstance(kwargs["out"], tuple): - raise RuntimeError(f"out= variant {op} with tuple out= not supported") - if kwargs["out"].nelement() == 0: - # Out variant that needs a resize, convert to an out of place - # and handle generically below - orig_out = kwargs["out"] - del kwargs["out"] - if op._overloadname != "out": - raise RuntimeError( - "Cannot retranslate non-default out= variant form 0 size" - ) - op = op.overloadpacket.default - - def _post_process(): - nonlocal real_res - orig_out.set_(real_res) - real_res = orig_out - - post_process = _post_process - - else: - # No metadata update to do, just run the op on the device - op_name = op.overloadpacket._qualified_op_name - real_res = kwargs["out"] - elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)): - # No Tensor argument means factory function - # They should decompose and be handled in our c++ side directly - raise RuntimeError(f"{op} not handled yet.") - elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: - # Only handle inplace ops returning their first arg - assert len(args) >= 1, f"Inplace {op} needs at least one arg" - assert len(op._schema.returns) == 1, ( - f"NYI Inplace {op} with more than one return" - ) - op_name = op.overloadpacket._qualified_op_name - real_res = args[0] - elif any(r.alias_info is not None for r in op._schema.returns): - # View ops - if op is torch.ops.aten.view.default: - return torch.ops.aten._unsafe_view(*args, **kwargs) - raise RuntimeError(f"{op} view op is not handled yet") - - if op_name is None: - # 1. Compute updated metadata - if torch.Tag.dynamic_output_shape not in op.tags: - # Usual case: run the meta op to see the output metadata - meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs) - meta_res = op(*meta_args, **meta_kwargs) - - # 2. Allocate the output - real_res, _ = to_device_no_copy("openreg", meta_res, {}) - else: - # Slow version for data-dependent functions: - # Run the op on the device just to get the output shape - args_, kwargs_ = prepare_for_sending(args, kwargs) - shape = driver.exec( - "get_op_output_shape", - op.overloadpacket._qualified_op_name, - args_, - kwargs_, - ) - - # 2. Allocate the output - real_res = args[0].new(shape) - - # 3. Move to out variant - kwargs["out"] = real_res - # Let overload resolution find the out= overload - op_name = op.overloadpacket._qualified_op_name - - # 4. Run the compute and populate the output on the device - args, kwargs = prepare_for_sending(args, kwargs) - driver.exec("run_op", op_name, args, kwargs) - - if post_process is not None: - post_process() - - return real_res - - -def copy_from_device(from_): - with torch.openreg.device(from_.device): # type: ignore[misc] - args, _ = prepare_for_sending((from_,), {}) - return driver.exec("send_data", *args) - - -def copy_from_host_to_device(from_, to_): - with torch.openreg.device(to_.device): # type: ignore[misc] - args, _ = prepare_for_sending((to_,), {}) - driver.exec("recv_data", from_, *args) - return to_ - - -def _copy_from(from_, to_): - if from_.device.type == to_.device.type: - assert from_.device.type == "openreg" - if from_.device.index == to_.device.index: - op = torch.ops.aten.copy_.default - return _openreg_kernel_fallback(op, to_, from_) - else: - host_mem = copy_from_device(from_) - return copy_from_host_to_device(host_mem, to_) - elif from_.device.type == "openreg": - host_mem = copy_from_device(from_) - return to_.copy_(host_mem) - elif to_.device.type == "openreg": - return copy_from_host_to_device(from_, to_) - else: - raise RuntimeError("Should not happen") - - -def _set_source_tensor(ten1, ten2): - return torch.ops.aten.set_.source_Storage_storage_offset( - ten1, - ten2.untyped_storage(), - ten2.storage_offset(), - ten2.size(), - ten2.stride(), - ) - - -def _local_scalar_dense(ten): - host_mem = copy_from_device(ten) - return host_mem.item() - - -_openreg_lib = torch.library.Library("_", "IMPL") -_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") - -_openreg_lib_aten = torch.library.Library("aten", "IMPL") -_openreg_lib_aten.impl("_copy_from", _copy_from, dispatch_key="PrivateUse1") -_openreg_lib_aten.impl( - "set_.source_Tensor", _set_source_tensor, dispatch_key="PrivateUse1" -) -_openreg_lib_aten.impl( - "_local_scalar_dense", _local_scalar_dense, dispatch_key="PrivateUse1" -) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py deleted file mode 100644 index d33986963500..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py +++ /dev/null @@ -1,391 +0,0 @@ -import ctypes -import logging -import threading -import time - -import torch - -from ._meta_parser import ( - OpenRegTensorData, - receive_after_sending, - safe_str, - validate_send_queue_args, -) - - -log = logging.getLogger(__name__) -mp_context = torch.multiprocessing.get_context("spawn") - -# Constant properties of our device -NUM_DEVICES = 2 - - -# Our allocator -class Allocator: - def __init__(self): - self.allocated = {} - - def malloc(self, size): - mem = ctypes.create_string_buffer(size) - ptr = ctypes.addressof(mem) - self.allocated[ptr] = (size, mem) - return ptr - - def free(self, ptr): - if ptr not in self.allocated: - return False - else: - del self.allocated[ptr] - return True - - -class HostAllocator(Allocator): - def is_pinned_ptr(self, ptr): - return ptr in self.allocated or any( - ptr_ <= ptr and ptr < ptr_ + size - for ptr_, (size, _) in self.allocated.items() - ) - - -class DeviceAllocator(Allocator): - def tensor_from_meta(self, meta): - def create_tensor_from_data_ptr(ptr, size): - storage = torch._C._construct_storage_from_data_pointer( - ptr, torch.device("cpu"), size - ) - return torch.Tensor(storage) - - found_base = None - # Usual case, we're receiving a known Tensor - if meta.data_ptr in self.allocated: - found_base = create_tensor_from_data_ptr( - meta.data_ptr, self.allocated[meta.data_ptr][0] - ) - - # Might be a rewrap of another storage at a different offset - # Slow path to try and find the corresponding storage - if found_base is None: - for tag, (size, _) in self.allocated.items(): - # t is always a 1D uint8 storage! - if meta.data_ptr > tag and meta.data_ptr < tag + size: - # Blame @ngimel for this - slice_size = size - (meta.data_ptr - tag) - found_base = create_tensor_from_data_ptr(meta.data_ptr, slice_size) - - # Might be an empty tensor - if found_base is None and meta.nelem_in_bytes == 0: - found_base = torch.tensor((), dtype=torch.uint8) - - # This pointer is not allocated here, segfault ! - if found_base is None: - log.info("Currently allocated blocks:\n %s", safe_str(self.allocated)) - log.info("Trying to access %s", meta) - raise RuntimeError("SEGFAULT!") - - # Raw 1d uint8 data - raw = found_base - # Reinterpret cast in the right dtype - as_dtype = raw.view(dtype=meta.dtype) - # View to the right shape/stride/offset - view = as_dtype.as_strided(meta.size, meta.stride, meta.storage_offset) - return view - - -def register(registry): - def func(fn): - registry[fn.__name__] = fn - return fn - - return func - - -class Driver: - def __init__(self, num_devices): - super().__init__() - self.num_devices = num_devices - self.is_initialized = False - - # State of our driver - self.curr_device_idx = 0 - self.curr_streams = {} - - # Allocated memory belongs to which device - self.memory_belong = {} - self.host_allocator = HostAllocator() - self.event_belong = {} - - self.rlock = threading.RLock() - - def _lazy_init(self): - if self.is_initialized: - return - self.devices = [] - - for i in range(self.num_devices): - req_queue = mp_context.Queue() - ans_queue = mp_context.Queue() - runner = mp_context.Process( - target=_Executor(i).run_forever, - args=(req_queue, ans_queue), - daemon=True, - ) - runner.start() - self.devices.append((req_queue, ans_queue, runner)) - - self.is_initialized = True - - def exec(self, cmd, *args): - with self.rlock: - log.info("Main process launched: %s(*%s)", cmd, safe_str(args)) - - if cmd in Driver.registry: - res = Driver.registry[cmd](self, *args) - else: - res = self.run_on_executor(self.curr_device_idx, cmd, *args) - - log.info("Main process result for %s received: %s", cmd, safe_str(res)) - if res == "ERROR": - raise RuntimeError(f"Error in daemon while executing {cmd}, see logs") - else: - return res - - def run_on_executor(self, device_idx, cmd, *args): - self._lazy_init() - req_queue, ans_queue, _ = self.devices[device_idx] - stream = self.getStream(device_idx) - validate_send_queue_args(cmd, args) - req_queue.put((stream, cmd) + args) - return ans_queue.get() - - registry = {} - - @register(registry) - def hasPrimaryContext(self, device_idx): - return device_idx >= 0 and device_idx < self.num_devices - - @register(registry) - def deviceCount(self, *args): - assert len(args) == 0 - return self.num_devices - - @register(registry) - def getDevice(self): - return self.curr_device_idx - - @register(registry) - def setDevice(self, device_idx): - assert device_idx >= 0 and device_idx < self.num_devices - self.curr_device_idx = device_idx - - @register(registry) - def uncheckedSetDevice(self, *args): - assert len(args) == 1 - self.curr_device_idx = int(args[0]) - - @register(registry) - def exchangeDevice(self, *args): - assert len(args) == 1 - res = self.curr_device_idx - self.curr_device_idx = int(args[0]) - return res - - @register(registry) - def malloc(self, size): - ptr = self.run_on_executor(self.curr_device_idx, "malloc", size) - self.memory_belong[ptr] = self.curr_device_idx - return ptr - - @register(registry) - def free(self, ptr): - device_idx = self.memory_belong.pop(ptr, None) - if device_idx is None: - return False - return self.run_on_executor(device_idx, "free", ptr) - - @register(registry) - def isPinnedPtr(self, ptr): - return self.host_allocator.is_pinned_ptr(ptr) - - @register(registry) - def hostMalloc(self, size): - return self.host_allocator.malloc(size) - - @register(registry) - def hostFree(self, ptr): - return self.host_allocator.free(ptr) - - @register(registry) - def getNewStream(self, device_idx, priority): - return self.run_on_executor(device_idx, "getNewStream", priority) - - @register(registry) - def queryStream(self, stream): - return self.run_on_executor( - stream.device_index, "queryStream", stream.stream_id - ) - - @register(registry) - def getStream(self, device_idx): - return self.curr_streams.get(device_idx, 0) - - @register(registry) - def exchangeStream(self, stream): - stream_id = self.curr_streams.get(stream.device_index, 0) - self.curr_streams[stream.device_index] = stream.stream_id - return stream_id - - @register(registry) - def synchronizeStream(self, stream): - self.run_on_executor(stream.device_index, "synchronizeStream", stream.stream_id) - - @register(registry) - def record(self, event, stream, device_index, flags): - event_ptr = ctypes.cast(event, ctypes.POINTER(ctypes.c_int64)) - # Create event if needed - if event_ptr.contents.value == 0: - event_ptr.contents.value = self.run_on_executor( - stream.device_index, "eventCreateWithFlags", flags - ) - self.event_belong[event_ptr.contents.value] = stream.device_index - - # Record event - self.run_on_executor( - stream.device_index, - "eventRecord", - event_ptr.contents.value, - stream.stream_id, - ) - - @register(registry) - def destroyEvent(self, event, device_index): - self.run_on_executor(device_index, "eventDestroy", event) - self.event_belong.pop(event) - - @register(registry) - def synchronizeEvent(self, event): - self.run_on_executor(self.event_belong[event], "eventSynchronize", event) - - @register(registry) - def queryEvent(self, event): - return self.run_on_executor(self.event_belong[event], "eventQuery", event) - - @register(registry) - def elapsedTime(self, e1, e2, device_index): - return self.run_on_executor(device_index, "eventElapsedTime", e1, e2) - - @register(registry) - def block(self, event, stream): - self.run_on_executor(stream.device_index, "block", event, stream.stream_id) - - -class _Executor: - def __init__(self, id): - self.id = id - self.allocator = DeviceAllocator() - self.stream = 0 - self.event_incr_id = 0 - self.events = {} - - def run_forever(self, req_queue, ans_queue): - # Serve all requests - while True: - # Ignore stream since cpu backend doesn't support asynchronous execution - _, cmd, *args = req_queue.get() - log.info("Worker executing: %s", cmd) - if cmd in _Executor.registry: - res = _Executor.registry[cmd](self, *args) - else: - log.warning("Bad command in worker") - res = "ERROR" - - log.info("Worker answering to: %s", cmd) - ans_queue.put(res) - - registry = {} - - @register(registry) - def malloc(self, size): - return self.allocator.malloc(size) - - @register(registry) - def free(self, ptr): - return self.allocator.free(ptr) - - def _run_op(self, op_name, args, kwargs): - op, _ = torch._C._jit_get_operation(op_name) - args, kwargs = receive_after_sending(self.allocator, args, kwargs) - return op(*args, **kwargs) - - @register(registry) - def run_op(self, op_name, args, kwargs): - self._run_op(op_name, args, kwargs) - - @register(registry) - def get_op_output_shape(self, op_name, args, kwargs): - return self._run_op(op_name, args, kwargs).size() - - @register(registry) - def send_data(self, *args): - assert len(args) == 1 - return OpenRegTensorData.from_meta(self.allocator, args[0]) - - @register(registry) - def recv_data(self, host_tensor, dev_mem): - dev_tensor = OpenRegTensorData.from_meta(self.allocator, dev_mem) - dev_tensor.copy_(host_tensor) - - @register(registry) - def getNewStream(self, priority): - self.stream += 1 - return self.stream - - @register(registry) - def queryStream(self, stream): - return True - - @register(registry) - def synchronizeStream(self, stream): - # no-op - pass - - @register(registry) - def eventCreateWithFlags(self, flags): - self.event_incr_id += 1 - self.events[self.event_incr_id] = [flags, None] - return self.event_incr_id - - @register(registry) - def eventRecord(self, event, stream): - # Only flags == 1 enables timing - if self.events[event][0] == 1: - self.events[event][1] = time.time() * 1000 - return 0 - - @register(registry) - def eventDestroy(self, event): - self.events.pop(event) - - @register(registry) - def eventSynchronize(self, event): - assert self.events.get(event) is not None - return 0 - - @register(registry) - def eventQuery(self, event): - assert self.events.get(event) is not None - return True - - @register(registry) - def eventElapsedTime(self, e1, e2): - time_1 = self.events[e1][1] - time_2 = self.events[e2][1] - assert time_1 is not None and time_2 is not None - return time_2 - time_1 - - @register(registry) - def block(self, event, stream): - # no-op - pass - - -driver = Driver(NUM_DEVICES) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py deleted file mode 100644 index 0f54f2ec4df0..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py +++ /dev/null @@ -1,103 +0,0 @@ -import pprint - -import torch -from torch.utils._pytree import tree_map, tree_map_only - - -class OpenRegTensorMeta: - def __init__(self, tensor, checked=True): - if checked and not tensor.device.type == "openreg": - raise RuntimeError( - "Creating OpenRegTensorMeta is only for Tensors on openreg device" - ) - self.data_ptr = tensor.untyped_storage().data_ptr() - self.size = tensor.size() - self.stride = tensor.stride() - self.storage_offset = tensor.storage_offset() - self.dtype = tensor.dtype - self.nelem_in_bytes = tensor.nelement() * tensor.element_size() - - def __repr__(self): - return ( - f"OpenRegTensorMeta({self.data_ptr=}, {self.size=}, {self.stride=}, " - f"{self.storage_offset=}, {self.dtype=}, {self.nelem_in_bytes=})" - ) - - -class OpenRegTensorData(torch.Tensor): - @staticmethod - def from_meta(allocator, tensor_meta): - return OpenRegTensorData(allocator.tensor_from_meta(tensor_meta)) - - -VALID_QUEUE_TYPES_IN = {torch.Tensor, int, float} - -VALID_QUEUE_TYPES_OUT = {OpenRegTensorMeta, int, float, str} - - -def safe_str(args): - def convert(obj): - if isinstance(obj, torch.Tensor): - return str(OpenRegTensorMeta(obj, checked=False)) - else: - return obj - - new_args = tree_map(convert, args) - return pprint.pformat(new_args) - - -def validate_send_queue_args(cmd, args): - def check(obj): - if type(obj) not in VALID_QUEUE_TYPES_OUT: - if ( - cmd == "recv_data" - and type(obj) in [torch.Tensor, OpenRegTensorData] - and obj.device.type == "cpu" - ): - # Only HtoD copy command can send cpu Tensors over - return - raise RuntimeError( - f"Trying to send invalid object through queue: {type(obj)}" - ) - - tree_map(check, args) - - -def prepare_for_sending(args, kwargs): - def convert(obj): - if type(obj) not in VALID_QUEUE_TYPES_IN: - raise RuntimeError( - f"Cannot send object of type {type(obj)} over openreg device pipe." - ) - - if isinstance(obj, torch.Tensor): - return OpenRegTensorMeta(obj) - else: - return obj - - return tree_map(convert, (args, kwargs)) - - -def receive_after_sending(allocator, args, kwargs): - def convert(obj): - if type(obj) not in VALID_QUEUE_TYPES_OUT: - raise RuntimeError( - f"Received invalid object of type {type(obj)} over openreg device pipe." - ) - - if isinstance(obj, OpenRegTensorMeta): - return allocator.tensor_from_meta(obj) - else: - return obj - - return tree_map(convert, (args, kwargs)) - - -def to_device_no_copy(device, args, kwargs): - def safe_to(t): - if device == "meta": - return t.to(device=device) - else: - return torch.empty_like(t, device=device) - - return tree_map_only(torch.Tensor, safe_to, (args, kwargs)) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp deleted file mode 100644 index 4580629454b7..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "OpenReg.h" - -#include - -#include -#include -#include -#include - -static PyObject* _initExtension(PyObject* self, PyObject* noargs) { - HANDLE_TH_ERRORS - - at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); - - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { - HANDLE_TH_ERRORS - TORCH_CHECK( - THPUtils_checkLong(arg), - "_get_default_generator expects an int, but got ", - THPUtils_typename(arg)); - auto idx = static_cast(THPUtils_unpackLong(arg)); - - return THPGenerator_initDefaultGenerator( - at::globalContext().defaultGenerator( - c10::Device(c10::DeviceType::PrivateUse1, idx))); - - END_HANDLE_TH_ERRORS -} - -static PyMethodDef methods[] = { - {"_init", _initExtension, METH_NOARGS, nullptr}, - {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, - {nullptr, nullptr, 0, nullptr} -}; - -static struct PyModuleDef openreg_C_module = - {PyModuleDef_HEAD_INIT, "pytorch_openreg._C", nullptr, -1, methods}; - -PyMODINIT_FUNC PyInit__C(void) { - PyObject* mod = PyModule_Create(&openreg_C_module); - - py::object openreg_mod = py::module_::import("pytorch_openreg"); - // Only borrowed from the python side! - openreg::set_impl_factory(openreg_mod.attr("impl_factory").ptr()); - - return mod; -} diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h deleted file mode 100644 index a04248f2e502..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h +++ /dev/null @@ -1,50 +0,0 @@ -#pragma once - -#include - -namespace openreg { - -using openreg_ptr_t = uint64_t; - -void set_impl_factory(PyObject* factory); -py::function get_method(const char* name); - -static constexpr char kFreeMethod[] = "free"; -static constexpr char kHostFreeMethod[] = "hostFree"; - -template -static void ReportAndDelete(void* ptr) { - if (!ptr || !Py_IsInitialized()) { - return; - } - - py::gil_scoped_acquire acquire; - - PyObject *type = nullptr, *value = nullptr, *traceback = nullptr; - // Always stash, this will be a no-op if there is no error - PyErr_Fetch(&type, &value, &traceback); - - TORCH_CHECK( - get_method(name)(reinterpret_cast(ptr)).cast(), - "Failed to free memory pointer at ", - ptr); - - // If that user code raised an error, just print it without raising it - if (PyErr_Occurred()) { - PyErr_Print(); - } - - // Restore the original error - PyErr_Restore(type, value, traceback); -} - -#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ - FOR_SERIALIZATION, FOR_DESERIALIZATION) \ - static int register_serialization() { \ - torch::jit::TensorBackendMetaRegistry( \ - c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ - return 0; \ - } \ - static const int _temp = register_serialization(); - -} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp deleted file mode 100644 index a87b378fb95c..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp +++ /dev/null @@ -1,350 +0,0 @@ -#include "OpenReg.h" - -#include -#include -#include - -#include -#include -#include - -namespace openreg { -namespace { - -// Python factory function where real implementations can be found -PyObject* py_factory; - -struct HostAllocator final : at::Allocator { - HostAllocator() = default; - - at::DataPtr allocate(size_t nbytes) override { - py::gil_scoped_acquire acquire; - void* data = nullptr; - if (nbytes > 0) { - data = reinterpret_cast( - get_method("hostMalloc")(nbytes).cast()); - TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); - } - return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - py::gil_scoped_acquire acquire; - get_method("hostCopyData")( - reinterpret_cast(dest), - reinterpret_cast(src), - count); - } -}; - -static HostAllocator global_host_alloc; - -static c10::DeviceIndex device_count() { - py::gil_scoped_acquire acquire; - return get_method("deviceCount")().cast(); -} - -static c10::DeviceIndex current_device_idx() { - py::gil_scoped_acquire acquire; - return get_method("getDevice")().cast(); -} - -class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { - public: - OpenRegGeneratorImpl(c10::DeviceIndex device_index) { - device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); - key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); - } - ~OpenRegGeneratorImpl() override = default; -}; - -static at::Generator make_openreg_generator(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -// Default, global generators, one per device. -static std::vector default_generators; - -struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { - OpenRegHooksInterface() {}; - ~OpenRegHooksInterface() override = default; - - bool hasPrimaryContext(c10::DeviceIndex device_index) const override { - py::gil_scoped_acquire acquire; - return get_method("hasPrimaryContext")(device_index).cast(); - } - - at::Allocator* getPinnedMemoryAllocator() const override { - return &global_host_alloc; - } - - bool isPinnedPtr(const void* data) const override { - py::gil_scoped_acquire acquire; - return get_method("isPinnedPtr")(reinterpret_cast(data)) - .cast(); - } - - const at::Generator& getDefaultGenerator( - c10::DeviceIndex device_index) const override { - static bool flag [[maybe_unused]] = []() { - auto deivce_nums = device_count(); - default_generators.resize(deivce_nums); - for (auto i = 0; i < deivce_nums; i++) { - default_generators[i] = make_openreg_generator(i); - default_generators[i].seed(); - } - return true; - }(); - - c10::DeviceIndex idx = device_index; - if (idx == -1) { - idx = current_device_idx(); - } else { - TORCH_CHECK(idx >= 0 && idx < device_count()); - } - return default_generators[idx]; - } - - at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { - return make_openreg_generator(device_index); - } -}; - -static bool register_hook_flag [[maybe_unused]] = []() { - at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); - - return true; -}(); - -// Device guard registration -struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { - static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; - - OpenRegGuardImpl() = default; - explicit OpenRegGuardImpl(c10::DeviceType t) { - TORCH_INTERNAL_ASSERT(t == static_type); - } - - /** - * Return the type of device managed by this guard implementation. - */ - c10::DeviceType type() const override { - return static_type; - } - - /** - * Set the current device to Device, and return the previous c10::Device. - */ - c10::Device exchangeDevice(c10::Device d) const override { - TORCH_INTERNAL_ASSERT(d.is_privateuseone()); - py::gil_scoped_acquire acquire; - auto old_device_index = - get_method("exchangeDevice")(d.index()).cast(); - return c10::Device(static_type, old_device_index); - } - - /** - * Get the current device. - */ - c10::Device getDevice() const override { - return c10::Device(static_type, current_device_idx()); - } - - /** - * Set the current device to c10::Device. - */ - void setDevice(c10::Device d) const override { - TORCH_INTERNAL_ASSERT(d.is_privateuseone()); - py::gil_scoped_acquire acquire; - auto device = get_method("setDevice")(d.index()); - } - - /** - * Set the current device to c10::Device, without checking for errors - * (so, e.g., this can be called from a destructor). - */ - void uncheckedSetDevice(c10::Device d) const noexcept override { - py::gil_scoped_acquire acquire; - auto device = get_method("uncheckedSetDevice")(d.index()); - } - - /** - * Get the current stream for a given device. - */ - c10::Stream getStream(c10::Device d) const noexcept override { - py::gil_scoped_acquire acquire; - auto stream_id = get_method("getStream")(d.index()).cast(); - return c10::Stream(c10::Stream::UNSAFE, d, stream_id); - } - - /** - * Get the default stream for a given device. - */ - c10::Stream getDefaultStream(c10::Device d) const override { - py::gil_scoped_acquire acquire; - return get_method("getDefaultStream")(d.index()).cast(); - } - - /** - * Get a stream from the global pool for a given device. - */ - c10::Stream getStreamFromGlobalPool( - c10::Device d, - bool isHighPriority = false) const override { - py::gil_scoped_acquire acquire; - return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority) - .cast(); - } - - /** - * Return a new stream for a given device and priority. The stream will be - * copied and shared around, device backend should be able to correctly handle - * the lifetime of the stream. - */ - c10::Stream getNewStream(c10::Device d, int priority = 0) const override { - py::gil_scoped_acquire acquire; - auto stream_id = - get_method("getNewStream")(d.index(), priority).cast(); - return c10::Stream(c10::Stream::UNSAFE, d, stream_id); - } - - /** - * Set a stream to be the thread local current stream for its device. - * Return the previous stream for that device. You are NOT required - * to set the current device to match the device of this stream. - */ - c10::Stream exchangeStream(c10::Stream s) const noexcept override { - py::gil_scoped_acquire acquire; - auto stream_id = get_method("exchangeStream")(s).cast(); - return c10::Stream(c10::Stream::UNSAFE, s.device(), stream_id); - } - - /** - * Destroys the given event. - */ - void destroyEvent(void* event, const c10::DeviceIndex device_index) - const noexcept override { - py::gil_scoped_acquire acquire; - get_method("destroyEvent")((int64_t)event, device_index); - } - - /** - * Increments the event's version and enqueues a job with this version - * in the stream's work queue. When the stream process that job - * it notifies all streams waiting on / blocked by that version of the - * event to continue and marks that version as recorded. - * */ - void record( - void** event, - const c10::Stream& stream, - const c10::DeviceIndex device_index, - const c10::EventFlag flag) const override { - py::gil_scoped_acquire acquire; - get_method("record")((int64_t)event, stream, device_index, (int64_t)flag); - } - - /** - * Does nothing if the event has not been scheduled to be recorded. - * If the event was previously enqueued to be recorded, a command - * to wait for the version of the event that exists at the time of this call - * is inserted in the stream's work queue. - * When the stream reaches this command it will stop processing - * additional commands until that version of the event is marked as recorded. - */ - void block(void* event, const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("block")((int64_t)event, stream); - } - - /** - * Returns true if (and only if) - * (1) the event has never been scheduled to be recorded - * (2) the current version is marked as recorded. - * Returns false otherwise. - */ - bool queryEvent(void* event) const override { - py::gil_scoped_acquire acquire; - return get_method("queryEvent")((int64_t)event).cast(); - } - - /** - * Get the number of devices. WARNING: This is REQUIRED to not raise - * an exception. If there is some sort of problem, e.g., driver error, - * you should report that there are zero available devices. - */ - c10::DeviceIndex deviceCount() const noexcept override { - return device_count(); - } - /** - * Return true if all the work previously enqueued on the stream for - * asynchronous execution has completed running on the device. - */ - bool queryStream(const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - return get_method("queryStream")(stream).cast(); - } - - /** - * Wait (by blocking the calling thread) until all the work previously - * enqueued on the stream has completed running on the device. - */ - virtual void synchronizeStream(const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("synchronizeStream")(stream); - } - - /** - * Wait (by blocking the calling thread) until all the work previously - * recorded on the event has completed running on the device. - */ - void synchronizeEvent(void* event) const override { - py::gil_scoped_acquire acquire; - get_method("synchronizeEvent")((int64_t)event); - } - - /** - * Ensure the caching allocator (if any) is aware that the given DataPtr is - * being used on the given stream, and that it should thus avoid recycling the - * DataPtr until all work on that stream is done. - */ - void recordDataPtrOnStream( - const c10::DataPtr& data_ptr, - const c10::Stream& stream) const override { - py::gil_scoped_acquire acquire; - get_method("recordDataPtrOnStream")(data_ptr, stream); - } - - /** - * Fetch the elapsed time between two recorded events. - */ - double elapsedTime( - void* event1, - void* event2, - const c10::DeviceIndex device_index) const override { - py::gil_scoped_acquire acquire; - return get_method("elapsedTime")( - (int64_t)event1, (int64_t)event2, device_index) - .cast(); - } -}; - -// Register our device guard -C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); - -} // namespace - -// Setter for the python dictionary with implementations -void set_impl_factory(PyObject* factory) { - py_factory = factory; -} - -py::function get_method(const char* name) { - auto factory = py::cast(py_factory); - return factory(name); -} - -} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp deleted file mode 100644 index 4d9bde060118..000000000000 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegMem.cpp +++ /dev/null @@ -1,418 +0,0 @@ -#include "OpenReg.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include - -namespace openreg { -namespace { - -struct OpenRegAllocator final : at::Allocator { - OpenRegAllocator() = default; - - at::DataPtr allocate(size_t nbytes) override { - py::gil_scoped_acquire acquire; - auto curr_device_idx = get_method("getDevice")().cast(); - auto curr_device = - c10::Device(c10::DeviceType::PrivateUse1, curr_device_idx); - void* data = nullptr; - if (nbytes > 0) { - data = reinterpret_cast( - get_method("malloc")(nbytes).cast()); - TORCH_CHECK( - data, "Failed to allocator ", nbytes, " bytes on openreg device."); - } - return {data, data, &ReportAndDelete, curr_device}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - py::gil_scoped_acquire acquire; - get_method("copy_data")( - reinterpret_cast(dest), - reinterpret_cast(src), - count); - } -}; - -static OpenRegAllocator global_openreg_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); - -// Empty op needs C++ code and cannot be handled by python side fallback -at::Tensor empty_openreg( - c10::IntArrayRef size, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt, - std::optional memory_format_opt) { - const auto device = c10::device_or_default(device_opt); - const auto dtype = c10::dtype_or_default(dtype_opt); - TORCH_CHECK(device.is_privateuseone()); - TORCH_CHECK( - c10::layout_or_default(layout_opt) == c10::Layout::Strided, - "Non strided layout not supported"); - TORCH_CHECK( - !c10::pinned_memory_or_default(pin_memory_opt), - "Pin memory can only be on CPU"); - const c10::DeviceGuard device_guard(device); - constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); - return at::detail::empty_generic( - size, &global_openreg_alloc, pu1_dks, dtype, memory_format_opt); -} - -at::Tensor empty_strided_openreg( - c10::IntArrayRef size, - c10::IntArrayRef stride, - std::optional dtype_opt, - std::optional layout_opt, - std::optional device_opt, - std::optional pin_memory_opt) { - const auto device = c10::device_or_default(device_opt); - const auto dtype = c10::dtype_or_default(dtype_opt); - TORCH_CHECK(device.is_privateuseone()); - TORCH_CHECK( - c10::layout_or_default(layout_opt) == c10::Layout::Strided, - "Non strided layout not supported"); - TORCH_CHECK( - !c10::pinned_memory_or_default(pin_memory_opt), - "Pin memory can only be on CPU"); - const c10::DeviceGuard device_guard(device); - constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); - return at::detail::empty_strided_generic( - size, stride, &global_openreg_alloc, pu1_dks, dtype); -} - -at::Tensor as_strided_openreg( - const at::Tensor& self, - c10::IntArrayRef size, - c10::IntArrayRef stride, - std::optional storage_offset_) { - // Metadata-only change so we re-use the cpu impl - return at::cpu::as_strided(self, size, stride, storage_offset_); -} - -const at::Tensor& resize__openreg( - const at::Tensor& self, - c10::SymIntArrayRef size, - ::std::optional memory_format) { - return at::native::resize_( - self, C10_AS_INTARRAYREF_SLOW(size), memory_format); -} - -at::Tensor& set_source_Storage_storage_offsetset_openreg( - at::Tensor& result, - at::Storage storage, - int64_t storage_offset, - c10::IntArrayRef size, - c10::IntArrayRef stride) { - return at::cpu::set_(result, storage, storage_offset, size, stride); -} - -std::tuple -custom_scaled_dot_product_fused_attention_overrideable( - const at::Tensor & query, - const at::Tensor & key, - const at::Tensor & value, - const std::optional & attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_v = value.size(3); - const int64_t max_seqlen_q = query.size(2); - const int64_t max_seqlen_kv = key.size(2); - - auto opts = query.options(); - auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); - auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, - opts.dtype(at::kFloat)); - auto philox_seed = at::empty({}, at::dtype(at::kLong)); - auto philox_offset = at::empty({}, at::dtype(at::kLong)); - - return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask); -} - -std::tuple -custom_scaled_dot_product_fused_attention_overrideable_backward( - const at::Tensor & grad_out, - const at::Tensor & query, - const at::Tensor & key, - const at::Tensor & value, - const at::Tensor & attn_bias, - std::array grad_input_mask, - const at::Tensor & out, - const at::Tensor & logsumexp, - const at::Tensor & cum_seq_q, - const at::Tensor & cum_seq_k, - int64_t max_q, - int64_t max_k, - double dropout_p, - bool is_causal, - const at::Tensor & philox_seed, - const at::Tensor & philox_offset, - std::optional scale) { - return std::tuple( - at::empty_like(query), - at::empty_like(key), - at::empty_like(value), - at::empty_like(attn_bias)); -} -} - -// Using the simplest way to obtain continuous Tensor data and process it. -// This is a demo for using operand API, and you can add more complex logic -// for input and output tensor based on your custom device kernel. -void abs_kernel(at::TensorIteratorBase& iter) { - // Abs only have a input tensor and a output tensor. - auto& output_operand = iter.operand(0); - auto& input_operand = iter.operand(1); - auto& output_tensor_base = output_operand.tensor_base(); - auto& input_tensor_base = input_operand.tensor_base(); - TORCH_CHECK(!input_operand.original_tensor_base().defined(), - "input original tensor is defined."); - TORCH_CHECK(!output_operand.original_tensor_base().defined(), - "output original tensor is defined."); - // For easy test, only accept contiguous input tensor for calculate. - auto memory_format = input_tensor_base.suggest_memory_format(); - TORCH_CHECK(input_tensor_base.is_contiguous(memory_format), - "Input tensor need be contiguous."); - // Add necessary restrictions to ensure the security of the demo. - TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(), - "Intput and output tensor size are not equal."); - // Common dtype is calculate in TensorIteratorBase. - TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float, - "Only support float type.") - // Using for loop for abs calculate. - auto abs_function = [](float* output_ptr, const float* input_ptr, - const int64_t NUM) { - for (int64_t i = 0; i < NUM; ++i) { - *(output_ptr + i) = std::abs(*(input_ptr + i)); - } - }; - // To simplify the logic of the test demo code, - // we only use contiguous tensor to calculate on device side. - // And using input tensor memory format. - if (iter.is_contiguous()) { - // Add for will_resize flag check. You can convert to differernt - // tensor memory format when will_resize is True. - // If TensorIteratorConfig resize_outputs_ flag is true, and there are two - // situations: - // 1) Out tensor is undefined, and TensorIterator set will_resize to true; - // 2) Out tensor is defined and tensor size is not equal to input tensor size; - // TensorIterator set will_resize to true, and call set_output_raw_strided - // to resize output tensor. - // When output operand will_resize flag is ture, dummy - // device can convert tensor to dummy device preferred memory format. - // Here we don't convert tensor memory format, because it will become complex - // when dummy device want keep same memory format for training network. - TORCH_CHECK(output_operand.will_resize, - "output operand will_resize flag need be True."); - abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); - } else { - // Stride copy is not support for foo device, using cpu device instead. - // For abs op, the last situation is: output tensor is not contiguous with - // operand will_resize is False. - TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True."); - // Get a contiguous tensor with input memory format. - at::Tensor output = at::empty(output_tensor_base.sizes(), - input_tensor_base.options() - .memory_format(memory_format)); - // For structured op which inheried from TensorIteratorBase, maybe you need to - // call set_output_raw_strided function to update output stored in op sturctured. - // abs op is no need to do this. - output_operand.exchange_tensor(c10::MaybeOwned::owned(std::in_place, output)); - abs_function((float*)output_operand.tensor_base().mutable_data_ptr(), - (float*)iter.data_ptr(1), iter.numel()); - // Copy tensor base to original tensor base, and keep same scalar type and - // stride with cpu and gpu. - if (output_operand.original_tensor_base().defined() && - !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) { - output_operand.original_tensor().copy_(output_operand.tensor()); - output_operand.restore_original_tensor(); - } - } -} - -int64_t _fused_sdp_choice_privateuse1( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_mask, - double dropout_p, - bool is_causal, - std::optional scale, - bool enable_gqa) { - auto backend = sdp::SDPBackend::overrideable; - return static_cast(backend); -} - -void quantize_tensor_per_tensor_affine_privateuse1( - const at::Tensor& rtensor, - at::Tensor& qtensor, - double scale, - int64_t zero_point) { - // Just test the process, so do nothing -} - -struct CustomAutogradFnReturnsSelf - : public torch::autograd::Function { - static at::Tensor forward( - torch::autograd::AutogradContext* ctx, - at::Tensor self) { - return self; - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -struct CustomAutogradFnAliasing - : public torch::autograd::Function { - static at::Tensor forward( - torch::autograd::AutogradContext* ctx, - at::Tensor self) { - return self.view_symint(self.sym_sizes()); - } - - static torch::autograd::variable_list backward( - torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { - return {grad_output[0] * 0.5}; - } -}; - -at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { - return CustomAutogradFnReturnsSelf::apply(x); -} - -at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { - return CustomAutogradFnAliasing::apply(x); -} - -/* Notes: - * - * OpenReg is currently designed to simulate device memory through multiple - * subprocesses on purpose to ensure we don't mistakenly poke at the "device's - * memory" from the main process. And be able to simulate the same thing that - * happens with other accelerators: any metadata-only change is cpu-only - * (main process), any data change must go through to the device (other process) - * and any data transfer between the two is expensive (serializing the whole - * Tensor). - * - * Currently, for the efficiency of IPC, most operations are to pass the Tensor - * metadata, and only a small number of operations involving copy will serialize - * and pass the Tensor body by custom pickler provided by torch.multiprocess. - * - * Therefore, in principle, only operations related to Metadata modification can - * be directly implemented at the C++ level and registered in PrivateUse1; but - * if memory access is involved, the relevant operations must be implemented at - * the Python level, otherwise invalid memory access will result. - */ - -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("empty.memory_format", empty_openreg); - m.impl("empty_strided", empty_strided_openreg); - m.impl("as_strided", as_strided_openreg); - m.impl("resize_", resize__openreg); - m.impl("set_.source_Storage", at::native::set_); - m.impl("set_.source_Storage_storage_offset", set_source_Storage_storage_offsetset_openreg); - m.impl("quantize_per_tensor", at::native::quantize_per_tensor); - m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1); - m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable); - m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward); -} - -struct OpenRegBackendMeta : public c10::BackendMeta { - OpenRegBackendMeta(int version_number, int format_number) - : version_number_(version_number), format_number_(format_number) {} - - int version_number_{-1}; - int format_number_{-1}; -}; - -void for_serialization( - const at::Tensor& t, - std::unordered_map& m) { - auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); - - if (meta_ptr != nullptr) { - auto o_meta_ptr = dynamic_cast(meta_ptr); - if (o_meta_ptr->version_number_ == 1) { - m["version_number"] = true; - } - if (o_meta_ptr->format_number_ == 29) { - m["format_number"] = true; - } - } -} - -void for_deserialization( - const at::Tensor& t, - std::unordered_map& m) { - int version_number{-1}; - int format_number{-1}; - - if (m.find("version_number") != m.end()) { - version_number = 1; - } - if (m.find("format_number") != m.end()) { - format_number = 29; - } - - c10::intrusive_ptr meta{std::unique_ptr( - new OpenRegBackendMeta(version_number, format_number))}; - t.unsafeGetTensorImpl()->set_backend_meta(meta); -} - -REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) -} // namespace openreg - -namespace at::native { -REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &openreg::abs_kernel); -REGISTER_PRIVATEUSE1_DISPATCH( - quantize_tensor_per_tensor_affine_stub, - &openreg::quantize_tensor_per_tensor_affine_privateuse1); -REGISTER_PRIVATEUSE1_DISPATCH( - _fused_sdp_choice_stub, - &openreg::_fused_sdp_choice_privateuse1); -} // namespace at::native - -TORCH_LIBRARY(openreg, m) { - m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); - m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); -} - -TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { - m.impl("custom_autograd_fn_aliasing", &openreg::custom_autograd_fn_aliasing); - m.impl( - "custom_autograd_fn_returns_self", - &openreg::custom_autograd_fn_returns_self); -} diff --git a/test/cpp_extensions/open_registration_extension/setup.py b/test/cpp_extensions/open_registration_extension/setup.py deleted file mode 100644 index fa8c1308c6c5..000000000000 --- a/test/cpp_extensions/open_registration_extension/setup.py +++ /dev/null @@ -1,78 +0,0 @@ -import distutils.command.clean -import os -import platform -import shutil -import sys -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -PACKAGE_NAME = "pytorch_openreg" -version = 1.0 - -ROOT_DIR = Path(__file__).absolute().parent -CSRS_DIR = ROOT_DIR / "pytorch_openreg/csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove pytorch_openreg extension - for path in (ROOT_DIR / "pytorch_openreg").glob("**/*.so"): - path.unlink() - # Remove build directory - build_dirs = [ - ROOT_DIR / "build", - ] - for path in build_dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -if __name__ == "__main__": - if sys.platform == "win32": - vc_version = os.getenv("VCToolsVersion", "") - if vc_version.startswith("14.16."): - CXX_FLAGS = ["/sdl"] - else: - CXX_FLAGS = ["/sdl", "/permissive-"] - elif platform.machine() == "s390x": - # no -Werror on s390x due to newer compiler - CXX_FLAGS = {"cxx": ["-g", "-Wall"]} - else: - CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]} - - sources = list(CSRS_DIR.glob("*.cpp")) - - # Note that we always compile with debug info - ext_modules = [ - CppExtension( - name="pytorch_openreg._C", - sources=sorted(str(s) for s in sources), - include_dirs=[CSRS_DIR], - extra_compile_args=CXX_FLAGS, - ) - ] - - setup( - name=PACKAGE_NAME, - version=version, - author="PyTorch Core Team", - description="Example for PyTorch out of tree registration", - packages=find_packages(exclude=("test",)), - package_data={PACKAGE_NAME: ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=ext_modules, - python_requires=">=3.8", - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - ) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt new file mode 100644 index 000000000000..73163b8cb1ae --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + +include(GNUInstallDirs) +include(CheckCXXCompilerFlag) +include(CMakeDependentOption) + +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/") + +set(LINUX TRUE) +set(CMAKE_INSTALL_MESSAGE NEVER) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(CMAKE_INSTALL_LIBDIR lib) + +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) + +set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) +find_package(Torch REQUIRED) +include_directories(${PYTORCH_INSTALL_DIR}/include) + +if(DEFINED PYTHON_INCLUDE_DIR) + include_directories(${PYTHON_INCLUDE_DIR}) +else() + message(FATAL_ERROR "Cannot find Python directory") +endif() + +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) +add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) +add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md new file mode 100644 index 000000000000..e59013cea440 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/README.md @@ -0,0 +1,177 @@ +# PyTorch OpenReg + +## Background + +The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem. + +**Note:** + +The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**. + +### Purpose + +- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD. +- **Integration Example**: To serve as a reference example for new backend integration. +- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code. + +### Design Principles + +- **Minimality Principle**: The fundamental goal is to enable/verify all integration paths/mechanisms for a new backend to integrate to PyTorch. All functions follow a "just right" strategy to ensure the correctness of relevant integration capabilities. +- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch. + +## Directory Structure + +```shell +torch_openreg/ +├── CMakeLists.txt +├── csrc +│ ├── aten +│ │ ├── native +│ │ │ ├── Extra.cpp +│ │ │ ├── Minimal.cpp +│ │ │ └── ... +│ │ ├── OpenRegExtra.cpp +│ │ └── OpenRegMinimal.cpp +│ ├── CMakeLists.txt +│ └── runtime +│ ├── OpenRegDeviceAllocator.cpp +│ ├── OpenRegDeviceAllocator.h +│ ├── OpenRegFunctions.cpp +│ ├── OpenRegFunctions.h +│ ├── OpenRegGenerator.cpp +│ ├── OpenRegGenerator.h +│ ├── OpenRegGuard.cpp +│ ├── OpenRegGuard.h +│ ├── OpenRegHooks.cpp +│ ├── OpenRegHooks.h +│ ├── OpenRegHostAllocator.cpp +│ ├── OpenRegHostAllocator.h +│ └── ... +├── README.md +├── setup.py +├── third_party +│ └── openreg +└── torch_openreg + ├── csrc + │ ├── CMakeLists.txt + │ ├── Module.cpp + │ └── stub.c + ├── __init__.py + └── openreg + ├── __init__.py + └── random.py +``` + +**Dependencies**: + +```mermaid +graph LR + A[Python] + B[_C.so] + C[libtorch_bindings.so] + D[libtorch_openreg.so] + E[libopenreg.so] + + A --> B --> C --> D --> E +``` + +- `_C.so`: torch\_openreg/csrc/stub.c +- `libtorch_bindings.so`: torch\_openreg/csrc/\*.cpp +- `libtorch_openreg.so`: csrc +- `libopenreg.so`: third\_party/openreg + +**Key Directories**: + +- `csrc/`: Core device implementation, including operator registration, runtime, etc. + - `csrc/aten/`: Operator registration + - `csrc/aten/native/`: Specific operator implementations for the OpenReg device. + - `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion). + - `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators. + - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc. +- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU. +- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings). + - `torch_openreg/csrc/`: Python C++ binding code. + - `torch_openreg/openreg/`: Python API. + +## Currently Implemented Features + +### Operator Registration + +- Operator Implementation + + - `TORCH_LIBRARY` form + - Registering a specific operator for an existing schema: See `empty.memory_format` + - Registering an operator with a custom schema + - Extending an existing namespace: (TODO) + - Custom namespace: See `custom_autograd_fn_returns_self` + - Autograd: See `custom_autograd_fn_returns_self` + - STUB form: See `abs_stub` + + - Fallback + - Global Fallback: See `wrapper_cpu_fallback` + - Per-operator Fallback: (TODO) + + - AMP (TODO) + +### Memory Management + +- Device Memory Management (TODO) +- Host Memory Management (TODO) + +### Custom Storage + +- Adding custom device descriptions (TODO) +- Serialization support (TODO) + +### Autoload + +- (TODO) + +... + +## Installation and Usage + +### Installation + +```python +pip3 install -r requirements.txt + +python setup.py develop/install +``` + +### Usage Example + +After installation, you can use the `openreg` device in Python just like any other regular device. + +```python +import torch +import torch_openreg + +if not torch.openreg.is_available(): + print("OpenReg backend is not available in this build.") + exit() + +print("OpenReg backend is available!") + +device = torch.device("openreg") + +try: + x = torch.tensor([[1., 2.], [3., 4.]], device=device) + y = x + 2 + print("Result y:\n", y) + print(f"Device of y: {y.device}") + + z = y.cpu() + print("Result z:\n", z) + print(f"Device of z: {z.device}") + +except Exception as e: + print(f"\nAn error occurred: {e}") +``` + +## Future Plans + +- **Enhance Features**: AMP, memory management, generators, distributed computing, etc. (to reiterate, the fundamental goal is to verify the integration mechanism). +- **Improve Tests**: Add more test cases related to the integration mechanism. +- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation. +- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync. diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 000000000000..077f4cf3b640 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LIBRARY_NAME torch_openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu) +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp new file mode 100644 index 000000000000..3d8525697cc8 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp @@ -0,0 +1,138 @@ +#include "native/Extra.h" + +#include +#include + +#include + +namespace at::openreg { + +at::Tensor wrapper_quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor_openreg( + self, scale, zero_point, dtype); +} + +int64_t wrapper__fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return at::native::_fused_sdp_choice_openreg( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +wrapper__scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + return at::native::_scaled_dot_product_fused_attention_overrideable_openreg( + query, + key, + value, + attn_bias, + dropout_p, + is_causal, + return_debug_mask, + scale); +} + +std::tuple +wrapper_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return at::native:: + _scaled_dot_product_fused_attention_overrideable_backward_openreg( + grad_out, + query, + key, + value, + attn_bias, + grad_input_mask, + out, + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale); +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); + m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); + m.impl( + "_scaled_dot_product_fused_attention_overrideable", + &wrapper__scaled_dot_product_fused_attention_overrideable); + m.impl( + "_scaled_dot_product_fused_attention_overrideable_backward", + &wrapper_scaled_dot_product_fused_attention_overrideable_backward); +} + +} // namespace at::openreg + +namespace at::openreg { +TORCH_LIBRARY(openreg, m) { + m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} + +TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { + m.impl( + "custom_autograd_fn_returns_self", + &at::native::custom_autograd_fn_returns_self); + m.impl( + "custom_autograd_fn_aliasing", &at::native::custom_autograd_fn_aliasing); +} +} // namespace at::openreg + +namespace at::native { +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel_openreg); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &quantize_tensor_per_tensor_affine_stub_openreg); +REGISTER_PRIVATEUSE1_DISPATCH( + _fused_sdp_choice_stub, + &_fused_sdp_choice_openreg); +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp new file mode 100644 index 000000000000..fe75cdaea8b2 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegMinimal.cpp @@ -0,0 +1,128 @@ +#include "native/Minimal.h" + +#include +#include + +#include + +namespace at::openreg { + +at::Tensor wrapper_empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + return at::native::empty_memory_format_openreg( + size, + dtype_opt, + layout_opt, + device_opt, + pin_memory_opt, + memory_format_opt); +} + +at::Tensor wrapper_empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + return at::native::empty_strided_openreg( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +at::Tensor wrapper_as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + return at::native::as_strided_openreg(self, size, stride, storage_offset); +} + +const at::Tensor& wrapper_resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_openreg_(self, size, memory_format); +} + +at::Tensor wrapper__reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias_openreg(self, size, stride); +} + +at::Tensor wrapper__copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + return at::native::_copy_from_openreg(self, dst, non_blocking); +} + +at::Tensor wrapper__copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + return at::native::_copy_from_and_resize_openreg(self, dst); +} + +at::Scalar wrapper__local_scalar_densor(const at::Tensor& self) { + return at::native::_local_scalar_dense_openreg(self); +} + +at::Tensor& wrapper_set_source_Tensor_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::set_source_Tensor_openreg_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::set_source_Storage_openreg_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_storage_offsetset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::native::set_source_Storage_storage_offset_openreg_( + result, storage, storage_offset, size, stride); +} + +at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) { + return at::native::view_openreg(self, size); +} + +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", wrapper_empty_memory_format); + m.impl("empty_strided", wrapper_empty_strided); + m.impl("as_strided", wrapper_as_strided); + m.impl("resize_", wrapper_resize_); + m.impl("_reshape_alias", wrapper__reshape_alias); + m.impl("_copy_from", wrapper__copy_from); + m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize); + m.impl("_local_scalar_dense", wrapper__local_scalar_densor); + m.impl("set_.source_Tensor", wrapper_set_source_Tensor_); + m.impl("set_.source_Storage", wrapper_set_source_Storage_); + m.impl( + "set_.source_Storage_storage_offset", + wrapper_set_source_Storage_storage_offsetset_); + m.impl("view", wrapper_view); +} + +void wrapper_cpu_fallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + at::native::cpu_fallback_openreg(op, stack); +} + +TORCH_LIBRARY_IMPL(_, PrivateUse1, m) { + m.fallback( + torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>()); +} + +} // namespace at::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h new file mode 100644 index 000000000000..a706137fe852 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Common.h @@ -0,0 +1,106 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +namespace at::native { + +class MemoryGuard { + public: + explicit MemoryGuard(const torch::jit::Stack& stack) { + for (const c10::IValue& ivalue : stack) { + find_and_unprotect_tensors(ivalue); + } + } + + template + explicit MemoryGuard(const Args&... args) { + (handler(args), ...); + } + + ~MemoryGuard() { + for (void* ptr : unprotected_pointers_) { + orMemoryProtect(ptr); + } + } + + MemoryGuard(const MemoryGuard&) = delete; + MemoryGuard& operator=(const MemoryGuard&) = delete; + MemoryGuard(MemoryGuard&&) = delete; + MemoryGuard& operator=(MemoryGuard&&) = delete; + + private: + void find_and_unprotect_tensors(const c10::IValue& ivalue) { + if (ivalue.isTensor()) { + unprotect_if_needed(ivalue.toTensor()); + } else if (ivalue.isTensorList()) { + for (const at::Tensor& tensor : ivalue.toTensorList()) { + unprotect_if_needed(tensor); + } + } else if (ivalue.isList()) { + for (const c10::IValue& element : ivalue.toListRef()) { + find_and_unprotect_tensors(element); + } + } else if (ivalue.isGenericDict()) { + for (const auto& pair : ivalue.toGenericDict()) { + find_and_unprotect_tensors(pair.key()); + find_and_unprotect_tensors(pair.value()); + } + } + } + + void unprotect_if_needed(const at::Tensor& tensor) { + if (!tensor.defined() || !tensor.has_storage()) { + return; + } + + void* ptr = tensor.data_ptr(); + orPointerAttributes attr; + + if (orPointerGetAttributes(&attr, ptr) == orSuccess) { + if (attr.type == orMemoryTypeDevice) { + if (unprotected_pointers_.find(attr.pointer) == + unprotected_pointers_.end()) { + orMemoryUnprotect(attr.pointer); + unprotected_pointers_.insert(attr.pointer); + } + } + } + } + + template + void handler(const T& x) { + if constexpr (std::is_same_v, at::Tensor>) { + unprotect_if_needed(x); + } + } + + std::set unprotected_pointers_; +}; + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp new file mode 100644 index 000000000000..741d14803539 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.cpp @@ -0,0 +1,238 @@ +#include "Extra.h" + +namespace at::native { + +at::Tensor quantize_per_tensor_openreg( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor(self, scale, zero_point, dtype); +} + +int64_t _fused_sdp_choice_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + auto backend = sdp::SDPBackend::overrideable; + return static_cast(backend); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_q = query.size(2); + const int64_t max_seqlen_kv = key.size(2); + + auto opts = query.options(); + auto output = + at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); + auto logsumexp = + at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto debug_attn_mask = at::empty( + {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, + opts.dtype(at::kFloat)); + auto philox_seed = at::empty({}, at::dtype(at::kLong)); + auto philox_offset = at::empty({}, at::dtype(at::kLong)); + + return std::make_tuple( + output, + logsumexp, + at::Tensor(), + at::Tensor(), + max_seqlen_q, + max_seqlen_kv, + philox_seed, + philox_offset, + debug_attn_mask); +} + +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward_openreg( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return std::tuple( + at::empty_like(query), + at::empty_like(key), + at::empty_like(value), + at::empty_like(attn_bias)); +} + +} // namespace at::native + +namespace at::native { + +void abs_kernel_openreg(at::TensorIteratorBase& iter) { + // Abs only have a input tensor and a output tensor. + auto& output_operand = iter.operand(0); + auto& input_operand = iter.operand(1); + auto& output_tensor_base = output_operand.tensor_base(); + auto& input_tensor_base = input_operand.tensor_base(); + TORCH_CHECK( + !input_operand.original_tensor_base().defined(), + "input original tensor is defined."); + TORCH_CHECK( + !output_operand.original_tensor_base().defined(), + "output original tensor is defined."); + // For easy test, only accept contiguous input tensor for calculate. + auto memory_format = input_tensor_base.suggest_memory_format(); + TORCH_CHECK( + input_tensor_base.is_contiguous(memory_format), + "Input tensor need be contiguous."); + // Add necessary restrictions to ensure the security of the demo. + TORCH_CHECK( + input_tensor_base.sizes() == output_tensor_base.sizes(), + "Intput and output tensor size are not equal."); + // Common dtype is calculate in TensorIteratorBase. + TORCH_CHECK( + iter.common_dtype() == at::ScalarType::Float, "Only support float type.") + // Using for loop for abs calculate. + auto abs_function = + [](float* output_ptr, const float* input_ptr, const int64_t NUM) { + for (int64_t i = 0; i < NUM; ++i) { + *(output_ptr + i) = std::abs(*(input_ptr + i)); + } + }; + // To simplify the logic of the test demo code, + // we only use contiguous tensor to calculate on device side. + // And using input tensor memory format. + if (iter.is_contiguous()) { + // Add for will_resize flag check. You can convert to differernt + // tensor memory format when will_resize is True. + // If TensorIteratorConfig resize_outputs_ flag is true, and there are two + // situations: + // 1) Out tensor is undefined, and TensorIterator set will_resize to true; + // 2) Out tensor is defined and tensor size is not equal to input tensor + // size; + // TensorIterator set will_resize to true, and call + // set_output_raw_strided to resize output tensor. + // When output operand will_resize flag is ture, dummy + // device can convert tensor to dummy device preferred memory format. + // Here we don't convert tensor memory format, because it will become + // complex when dummy device want keep same memory format for training + // network. + TORCH_CHECK( + output_operand.will_resize, + "output operand will_resize flag need be True."); + abs_function( + (float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel()); + } else { + // Stride copy is not support for foo device, using cpu device instead. + // For abs op, the last situation is: output tensor is not contiguous with + // operand will_resize is False. + TORCH_CHECK( + !output_operand.will_resize, "output operand will_resize is True."); + // Get a contiguous tensor with input memory format. + at::Tensor output = at::empty( + output_tensor_base.sizes(), + input_tensor_base.options().memory_format(memory_format)); + // For structured op which inheried from TensorIteratorBase, maybe you need + // to call set_output_raw_strided function to update output stored in op + // sturctured. abs op is no need to do this. + output_operand.exchange_tensor( + c10::MaybeOwned::owned(std::in_place, output)); + abs_function( + (float*)output_operand.tensor_base().mutable_data_ptr(), + (float*)iter.data_ptr(1), + iter.numel()); + // Copy tensor base to original tensor base, and keep same scalar type and + // stride with cpu and gpu. + if (output_operand.original_tensor_base().defined() && + !output_operand.original_tensor_base().is_same( + output_operand.tensor_base())) { + output_operand.original_tensor().copy_(output_operand.tensor()); + output_operand.restore_original_tensor(); + } + } +} + +void quantize_tensor_per_tensor_affine_stub_openreg( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) {} + +} // namespace at::native + +namespace at::native { + +namespace { +struct CustomAutogradFnReturnsSelf + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; +} // namespace + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h new file mode 100644 index 000000000000..95109cd3fa33 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Extra.h @@ -0,0 +1,70 @@ +#include "Common.h" + +namespace at::native { +at::Tensor quantize_per_tensor_openreg( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype); +int64_t _fused_sdp_choice_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa); +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable_openreg( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale); +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward_openreg( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale); +} // namespace at::native + +namespace at::native { +void abs_kernel_openreg(at::TensorIteratorBase& iter); +void quantize_tensor_per_tensor_affine_stub_openreg( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point); +} // namespace at::native + +namespace at::native { +at::Tensor custom_autograd_fn_returns_self(at::Tensor x); +at::Tensor custom_autograd_fn_aliasing(at::Tensor x); +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp new file mode 100644 index 000000000000..973869087a2e --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.cpp @@ -0,0 +1,173 @@ +#include "Minimal.h" + +namespace at::native { + +at::Tensor empty_memory_format_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_generic( + size, allocator, pu1_dks, dtype, memory_format_opt); +} + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_strided_generic( + size, stride, allocator, pu1_dks, dtype); +} + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + MemoryGuard guard(self); + + return at::cpu::as_strided_symint(self, size, stride, storage_offset); +} + +const at::Tensor& resize_openreg_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor _reshape_alias_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias( + self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride)); +} + +at::Tensor _copy_from_openreg( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + TORCH_CHECK(self.defined(), "Source tensor (self) is not defined."); + TORCH_CHECK(dst.defined(), "Destination tensor (dst) is not defined."); + + MemoryGuard guard(self, dst); + + if (self.device() == dst.device()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + const at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self_as_cpu, non_blocking); + + } else { + if (self.is_cpu()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self, non_blocking); + + } else { + at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst), self_as_cpu, non_blocking); + } + } + + return dst; +} + +at::Tensor _copy_from_and_resize_openreg( + const at::Tensor& self, + const at::Tensor& dst) { + at::native::resize_(dst, self.sizes(), std::nullopt); + + MemoryGuard guard(self, dst); + + return at::native::copy_(const_cast(dst), self, false); +} + +at::Scalar _local_scalar_dense_openreg(const at::Tensor& self) { + MemoryGuard guard(self); + return at::native::_local_scalar_dense_cpu(self); +} + +at::Tensor& set_source_Tensor_openreg_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::set_tensor_(self, source); +} + +at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source) { + return at::native::set_(self, source); +} + +at::Tensor& set_source_Storage_storage_offset_openreg_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + // call native:: + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + +at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size) { + MemoryGuard guard(self); + return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size)); +} + +void cpu_fallback_openreg( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + at::native::cpu_fallback(op, stack); +} + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h new file mode 100644 index 000000000000..3d144f2debea --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/native/Minimal.h @@ -0,0 +1,67 @@ +#include "Common.h" + +namespace at::native { + +at::Tensor empty_memory_format_openreg( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +at::Tensor empty_strided_openreg( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +at::Tensor as_strided_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset); + +const at::Tensor& resize_openreg_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format); + +at::Tensor _reshape_alias_openreg( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride); + +at::Tensor _copy_from_openreg( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking); + +at::Tensor _copy_from_and_resize_openreg( + const at::Tensor& self, + const at::Tensor& dst); + +at::Scalar _local_scalar_dense_openreg(const at::Tensor& self); + +at::Tensor& set_source_Tensor_openreg_( + at::Tensor& self, + const at::Tensor& source); + +at::Tensor& set_source_Storage_openreg_(at::Tensor& self, at::Storage source); + +at::Tensor& set_source_Storage_storage_offset_openreg_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride); + +at::Tensor view_openreg(const at::Tensor& self, c10::SymIntArrayRef size); + +void cpu_fallback_openreg( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +} // namespace at::native diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp new file mode 100644 index 000000000000..3d35b677cd20 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegDeviceAllocator.h" + +namespace c10::openreg { + +static OpenRegDeviceAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h new file mode 100644 index 000000000000..c9aea4a91342 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h @@ -0,0 +1,43 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegDeviceAllocator final : at::Allocator { + OpenRegDeviceAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + int current_device_index = -1; + orGetDevice(¤t_device_index); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + void* data = nullptr; + if (nbytes > 0) { + orMalloc(&data, nbytes); + TORCH_CHECK( + data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp new file mode 100644 index 000000000000..240c2d8ce1aa --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp @@ -0,0 +1,73 @@ +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +orError_t GetDeviceCount(int* dev_count) { + return orGetDeviceCount(dev_count); +} + +orError_t GetDevice(c10::DeviceIndex* device) { + int tmp_device = -1; + auto err = orGetDevice(&tmp_device); + *device = static_cast(tmp_device); + return err; +} + +orError_t SetDevice(c10::DeviceIndex device) { + int cur_device = -1; + orGetDevice(&cur_device); + if (device == cur_device) { + return orSuccess; + } + return orSetDevice(device); +} + +int device_count_impl() { + int count = 0; + GetDeviceCount(&count); + return count; +} + +c10::DeviceIndex device_count() noexcept { + // initialize number of devices only once + static int count = []() { + try { + auto result = device_count_impl(); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many devices, DeviceIndex overflowed"); + return result; + } catch (const c10::Error& ex) { + // We don't want to fail, but still log the warning + // msg() returns the message without the stack trace + TORCH_WARN("Device initialization: ", ex.msg()); + return 0; + } + }(); + return static_cast(count); +} + +c10::DeviceIndex current_device() { + c10::DeviceIndex cur_device = -1; + GetDevice(&cur_device); + return cur_device; +} + +void set_device(c10::DeviceIndex device) { + SetDevice(device); +} + +DeviceIndex ExchangeDevice(DeviceIndex device) { + int current_device = -1; + orGetDevice(¤t_device); + + if (device != current_device) { + orSetDevice(device); + } + + return current_device; +} + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h new file mode 100644 index 000000000000..b6b991ff6d3a --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +#include + +namespace c10::openreg { + +c10::DeviceIndex device_count() noexcept; +DeviceIndex current_device(); +void set_device(c10::DeviceIndex device); + +DeviceIndex ExchangeDevice(DeviceIndex device); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp new file mode 100644 index 000000000000..c2e03f66adc4 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.cpp @@ -0,0 +1,28 @@ +#include "OpenRegGenerator.h" + +// Default, global generators, one per device. +static std::vector default_generators; + +namespace c10::openreg { + +const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) { + static bool flag [[maybe_unused]] = []() { + auto deivce_nums = device_count(); + default_generators.resize(deivce_nums); + for (auto i = 0; i < deivce_nums; i++) { + default_generators[i] = at::make_generator(i); + default_generators[i].seed(); + } + return true; + }(); + + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = current_device(); + } else { + TORCH_CHECK(idx >= 0 && idx < device_count()); + } + return default_generators[idx]; +} + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h new file mode 100644 index 000000000000..877a9707306f --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGenerator.h @@ -0,0 +1,21 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { +class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { + public: + OpenRegGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~OpenRegGeneratorImpl() override = default; +}; + +const at::Generator& getDefaultOpenRegGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp new file mode 100644 index 000000000000..d50e56e40942 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.cpp @@ -0,0 +1,7 @@ +#include "OpenRegGuard.h" + +namespace c10::openreg { + +C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h new file mode 100644 index 000000000000..f0150fe680fb --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h @@ -0,0 +1,197 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +// Device guard registration +struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; + + OpenRegGuardImpl() = default; + explicit OpenRegGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == static_type); + } + + /** + * Return the type of device managed by this guard implementation. + */ + c10::DeviceType type() const override { + return static_type; + } + + /** + * Set the current device to Device, and return the previous c10::Device. + */ + c10::Device exchangeDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + auto old_device_index = ExchangeDevice(d.index()); + return c10::Device(static_type, old_device_index); + } + + /** + * Get the current device. + */ + c10::Device getDevice() const override { + int device_index = current_device(); + return c10::Device(static_type, device_index); + } + + /** + * Set the current device to c10::Device. + */ + void setDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Set the current device to c10::Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + void uncheckedSetDevice(c10::Device d) const noexcept override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Get the current stream for a given device. + */ + c10::Stream getStream(c10::Device d) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get the default stream for a given device. + */ + c10::Stream getDefaultStream(c10::Device d) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get a stream from the global pool for a given device. + */ + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + return s; + } + + /** + * Destroys the given event. + */ + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + static int event_id = 1; + + if (!*event) + *event = reinterpret_cast(event_id++); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(void* event, const c10::Stream& stream) const override {} + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool queryEvent(void* event) const override { + return true; + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + c10::DeviceIndex deviceCount() const noexcept override { + int device_index = -1; + orGetDeviceCount(&device_index); + return device_index; + } + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + bool queryStream(const c10::Stream& stream) const override { + return true; + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + void synchronizeStream(const c10::Stream& stream) const override {} + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + void synchronizeEvent(void* event) const override {} + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override {} + + /** + * Fetch the elapsed time between two recorded events. + */ + double elapsedTime( + void* event1, + void* event2, + const c10::DeviceIndex device_index) const override { + return 1; + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp new file mode 100644 index 000000000000..57bc2d9f0d1b --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.cpp @@ -0,0 +1,11 @@ +#include "OpenRegHooks.h" + +namespace c10::openreg { + +static bool register_hook_flag [[maybe_unused]] = []() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); + + return true; +}(); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h new file mode 100644 index 000000000000..656fba8eae48 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHooks.h @@ -0,0 +1,41 @@ +#include +#include + +#include +#include + +#include + +#include "OpenRegGenerator.h" + +namespace c10::openreg { +struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { + OpenRegHooksInterface() {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + return true; + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return at::getHostAllocator(at::kPrivateUse1); + } + + bool isPinnedPtr(const void* data) const override { + orPointerAttributes attr{}; + orPointerGetAttributes(&attr, data); + + return attr.type == orMemoryTypeHost; + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + return getDefaultOpenRegGenerator(device_index); + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return at::make_generator(device_index); + } +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp new file mode 100644 index 000000000000..552638035c38 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegHostAllocator.h" + +namespace c10::openreg { + +OpenRegHostAllocator caching_host_allocator; +REGISTER_HOST_ALLOCATOR(at::kPrivateUse1, &caching_host_allocator); + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h new file mode 100644 index 000000000000..edef545a2783 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegHostAllocator.h @@ -0,0 +1,48 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegHostAllocator final : at::HostAllocator { + OpenRegHostAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + void* data = nullptr; + if (nbytes > 0) { + orMallocHost(&data, nbytes); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyHostToHost); + } + + // ignore + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return true; + } + void empty_cache() override {} + at::HostStats get_stats() override { + return at::HostStats(); + } + void reset_accumulated_stats() override {} + void reset_peak_stats() override {} +}; + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp new file mode 100644 index 000000000000..43809d60604f --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.cpp @@ -0,0 +1,48 @@ +#include "OpenRegSerialization.h" + +namespace c10::openreg { +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) + +} // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h new file mode 100644 index 000000000000..559e92ea82f7 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegSerialization.h @@ -0,0 +1,10 @@ +#include + +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt new file mode 100644 index 000000000000..42d5e8d799f4 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/requirements.txt @@ -0,0 +1,2 @@ +torch +pybind11 diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py new file mode 100644 index 000000000000..38a866e4ce21 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py @@ -0,0 +1,102 @@ +import multiprocessing +import os +import shutil +import subprocess +import sys +import sysconfig +from distutils.command.clean import clean + +from setuptools import Extension, find_packages, setup + + +PACKAGE_NAME = "torch_openreg" +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) + + +def get_pytorch_dir(): + import torch + + return os.path.dirname(os.path.realpath(torch.__file__)) + + +def build_deps(): + build_dir = os.path.join(BASE_DIR, "build") + os.makedirs(build_dir, exist_ok=True) + + cmake_args = [ + "-DCMAKE_INSTALL_PREFIX=" + + os.path.realpath(os.path.join(BASE_DIR, "torch_openreg")), + "-DPYTHON_INCLUDE_DIR=" + sysconfig.get_paths().get("include"), + "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir(), + ] + + subprocess.check_call( + ["cmake", BASE_DIR] + cmake_args, cwd=build_dir, env=os.environ + ) + + build_args = [ + "--build", + ".", + "--target", + "install", + "--", + ] + build_args += ["-j", str(multiprocessing.cpu_count())] + + command = ["cmake"] + build_args + subprocess.check_call(command, cwd=build_dir, env=os.environ) + + +class BuildClean(clean): + def run(self): + for i in ["build", "install", "torch_openreg.egg-info", "torch_openreg/lib"]: + dirs = os.path.join(BASE_DIR, i) + if os.path.exists(dirs) and os.path.isdir(dirs): + shutil.rmtree(dirs) + + for dirpath, _, filenames in os.walk(os.path.join(BASE_DIR, "torch_openreg")): + for filename in filenames: + if filename.endswith(".so"): + os.remove(os.path.join(dirpath, filename)) + + +RUN_BUILD_DEPS = any(arg == "clean" for arg in sys.argv) + + +def main(): + if not RUN_BUILD_DEPS: + build_deps() + + ext_modules = [ + Extension( + name="torch_openreg._C", + sources=["torch_openreg/csrc/stub.c"], + extra_compile_args=["-g", "-Wall", "-Werror"], + libraries=["torch_bindings"], + library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], + extra_link_args=["-Wl,-rpath,$ORIGIN/lib"], + ) + ] + + package_data = {PACKAGE_NAME: ["lib/*.so*"]} + + setup( + name=PACKAGE_NAME, + version="0.0.1", + author="PyTorch Core Team", + description="Example for PyTorch out of tree registration", + packages=find_packages(exclude=("test",)), + package_data=package_data, + install_requires=[ + "torch", + ], + ext_modules=ext_modules, + python_requires=">=3.8", + cmdclass={ + "clean": BuildClean, # type: ignore[misc] + }, + ) + + +if __name__ == "__main__": + main() diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt new file mode 100644 index 000000000000..7fec109eeb1c --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt @@ -0,0 +1,11 @@ +set(LIBRARY_NAME openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md new file mode 100644 index 000000000000..af17ef3abdb1 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md @@ -0,0 +1,137 @@ +# OpenReg: An Accelerator Backend that Simulates CUDA Behavior on a CPU + +## Introduction + +OpenReg is a C++ backend library that simulates the behavior of a CUDA-like device on a CPU. Its core objective is **not to accelerate computation or improve performance**, but rather to **simulate modern CUDA programming, enabling developers to prototype and test in an environment without actual GPU hardware**. The current design principles are as follows: + +* **API Consistency**: Provide an interface consistent with the CUDA Runtime API, allowing upper-level applications (like PyTorch's PrivateUse1 backend) to switch and test seamlessly. +* **Functional Consistency**: Provide behavior consistent with the CUDA Runtime, such as memory isolation, device context management, etc. +* **Completeness**: Aim to support PrivateUse1 device integration and safeguard the third-party device integration mechanism, without striving to cover all capabilities of the CUDA Runtime. + +## Directory Structure + +The project's code is organized with a clear structure and separation of responsibilities: + +```text +openreg/ +├── CMakeLists.txt # Top-level CMake build script, used to compile and generate libopenreg.so +├── include/ +│ └── openreg.h # Public API header file, external users only need to include this file +└── csrc/ + ├── device.cpp # Implementation of device management-related APIs + └── memory.cpp # Implementation of APIs for memory management, copying, and protection +``` + +* `include/openreg.h`: Defines all externally exposed C-style APIs, data structures, and enums. It is the "public face" of this library. +* `csrc/`: Contains the C++ implementation source code for all core functionalities. + * `device.cpp`: Implements device discovery (`orGetDeviceCount`) and thread context management (`orSetDevice`/`orGetDevice`). + * `memory.cpp`: Implements the core functions of memory allocation (`orMalloc`/`orMallocHost`), deallocation, copying, and memory protection (`orMemoryProtect`, `orMemoryUnprotect`). +* `CMakeLists.txt`: Responsible for compiling and linking all source files under the `csrc/` directory to generate the final `libopenreg.so` shared library. + +## Implemented APIs + +OpenReg currently provides a set of APIs covering basic memory and device management. + +### Device Management APIs + +| OpenReg | CUDA | Feature Description | +| :------------------- | :------------------- | :------------------------------------------------ | +| `orGetDeviceCount` | `cudaGetDeviceCount` | Get the number of devices | +| `orSetDevice` | `cudaSetDevice` | Set the current device for the current thread | +| `orGetDevice` | `cudaGetDevice` | Get the current device for the current thread | + +### Memory Management APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :--------------------------- | :----------------------------------------- | +| `orMalloc` | `cudaMalloc` | Allocate device memory | +| `orFree` | `cudaFree` | Free device memory | +| `orMallocHost` | `cudaMallocHost` | Allocate page-locked (Pinned) host memory | +| `orFreeHost` | `cudaFreeHost` | Free page-locked host memory | +| `orMemcpy` | `cudaMemcpy` | Synchronous memory copy | +| `orMemcpyAsync` | `cudaMemcpyAsync` | Asynchronous memory copy | +| `orPointerGetAttributes` | `cudaPointerGetAttributes` | Get pointer attributes | +| `orMemoryUnprotect` | - | (Internal use) Unprotect memory | +| `orMemoryProtect` | - | (Internal use) Restore memory protection | + +## Implementation Principles + +### Device Management Principles + +Simulating multiple devices and thread-safe device context switching: + +1. **Device Count**: The total number of simulated devices is defined by the compile-time constant `constexpr int kDeviceCount`. +2. **Device Switching**: Device switching in multi-threaded scenarios is simulated using a **TLS (Thread-Local Storage) global variable**. + +### Memory Management Principles + +Simulating device memory, host memory, and memory copies: + +1. **Allocation**: A page-aligned memory block is allocated using `mmap` + `mprotect` with the permission flag `PROT_NONE`. Read, write, and execute operations on this memory region are all prohibited. +2. **Deallocation**: Memory is freed using `munmap`. +3. **Authorization**: When a legitimate memory access is required, an RAII guard restores the memory permissions to `PROT_READ | PROT_WRITE`. The permissions are automatically reverted to `PROT_NONE` when the scope is exited. + +## Usage Example + +The following is a simple code snippet demonstrating how to use the core features of the OpenReg library. + +```cpp +#include "openreg.h" +#include +#include +#include + +#define OR_CHECK(call) do { \ + orError_t err = call; \ + if (err != orSuccess) { \ + fprintf(stderr, "OR Error code %d in %s at line %d\n", err, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } \ +} while (0) + +int main() { + int device_count = 0; + OR_CHECK(orGetDeviceCount(&device_count)); + std::cout << "Found " << device_count << " simulated devices." << std::endl; + + int current_device = -1; + OR_CHECK(orSetDevice(1)); + OR_CHECK(orGetDevice(¤t_device)); + std::cout << "Set current device to " << current_device << "." << std::endl; + + const int n = 1024; + const size_t size = n * sizeof(int); + int *h_a, *d_a; + OR_CHECK(orMallocHost((void**)&h_a, size)); + OR_CHECK(orMalloc((void**)&d_a, size)); + + orPointerAttributes attr; + OR_CHECK(orPointerGetAttributes(&attr, d_a)); + std::cout << "Pointer " << (void*)d_a << " is of type " << attr.type + << " on device " << attr.device << std::endl; + + for (int i = 0; i < n; ++i) { + h_a[i] = i; + } + OR_CHECK(orMemcpy(d_a, h_a, size, orMemcpyHostToDevice)); + std::cout << "Data copied from Host to Device." << std::endl; + + // std::cout << "Trying to access device memory directly from CPU..." << std::endl; + // int val = d_a[0]; // CRASH! + + // Clean up resources + OR_CHECK(orFree(d_a)); + OR_CHECK(orFreeHost(h_a)); + std::cout << "Resources freed." << std::endl; + + return 0; +} +``` + +## Next Steps + +To better support PrivateUse1 device integration, the following capabilities are planned for the future: + +* **Stream Support**: Provide the ability to simulate CUDA Streams. +* **Event Support**: Provide the ability to simulate CUDA Events. +* **Cross-Platform Support**: Add support for Windows and macOS (low priority). diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp new file mode 100644 index 000000000000..3f1d43ea0b55 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/device.cpp @@ -0,0 +1,35 @@ +#include + +namespace { +// Total device numbers +constexpr int DEVICE_COUNT = 2; +// Current device index +thread_local int gCurrentDevice = 0; +} // namespace + +orError_t orGetDeviceCount(int* count) { + if (!count) { + return orErrorUnknown; + } + + *count = DEVICE_COUNT; + return orSuccess; +} + +orError_t orGetDevice(int* device) { + if (!device) { + return orErrorUnknown; + } + + *device = gCurrentDevice; + return orSuccess; +} + +orError_t orSetDevice(int device) { + if (device < 0 || device >= DEVICE_COUNT) { + return orErrorUnknown; + } + + gCurrentDevice = device; + return orSuccess; +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp new file mode 100644 index 000000000000..762cd96d23bb --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp @@ -0,0 +1,249 @@ +#include + +#include +#include +#include +#include +#include +#include + +namespace openreg { +namespace internal { + +class ScopedMemoryProtector { + public: + ScopedMemoryProtector(const orPointerAttributes& info) + : m_info(info), m_protected(false) { + if (m_info.type == orMemoryType::orMemoryTypeDevice) { + if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) == + 0) { + m_protected = true; + } + } + } + ~ScopedMemoryProtector() { + if (m_protected) { + mprotect(m_info.pointer, m_info.size, PROT_NONE); + } + } + ScopedMemoryProtector(const ScopedMemoryProtector&) = delete; + ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete; + + private: + orPointerAttributes m_info; + bool m_protected; +}; + +class MemoryManager { + public: + static MemoryManager& getInstance() { + static MemoryManager instance; + return instance; + } + + orError_t allocate(void** ptr, size_t size, orMemoryType type) { + if (!ptr || size == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + long page_size = sysconf(_SC_PAGESIZE); + size_t aligned_size = ((size - 1) / page_size + 1) * page_size; + void* mem = nullptr; + int current_device = -1; + + if (type == orMemoryType::orMemoryTypeDevice) { + orGetDevice(¤t_device); + + mem = mmap( + nullptr, + aligned_size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); + if (mem == MAP_FAILED) + return orErrorUnknown; + if (mprotect(mem, aligned_size, PROT_NONE) != 0) { + munmap(mem, aligned_size); + return orErrorUnknown; + } + } else { + if (posix_memalign(&mem, page_size, aligned_size) != 0) { + return orErrorUnknown; + } + } + + m_registry[mem] = {type, current_device, mem, aligned_size}; + *ptr = mem; + return orSuccess; + } + + orError_t free(void* ptr) { + if (!ptr) + return orSuccess; + + std::lock_guard lock(m_mutex); + auto it = m_registry.find(ptr); + if (it == m_registry.end()) + return orErrorUnknown; + const auto& info = it->second; + if (info.type == orMemoryType::orMemoryTypeDevice) { + mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE); + munmap(info.pointer, info.size); + } else { + ::free(info.pointer); + } + m_registry.erase(it); + return orSuccess; + } + + orError_t memcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + if (!dst || !src || count == 0) + return orErrorUnknown; + std::lock_guard lock(m_mutex); + orPointerAttributes dst_info = getPointerInfo(dst); + orPointerAttributes src_info = getPointerInfo(src); + switch (kind) { + case orMemcpyHostToDevice: + if (dst_info.type != orMemoryType::orMemoryTypeDevice || + src_info.type == orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyDeviceToHost: + if (dst_info.type == orMemoryType::orMemoryTypeDevice || + src_info.type != orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyDeviceToDevice: + if (dst_info.type != orMemoryType::orMemoryTypeDevice || + src_info.type != orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + case orMemcpyHostToHost: + if (dst_info.type == orMemoryType::orMemoryTypeDevice || + src_info.type == orMemoryType::orMemoryTypeDevice) + return orErrorUnknown; + break; + } + { + ScopedMemoryProtector dst_protector(dst_info); + ScopedMemoryProtector src_protector(src_info); + ::memcpy(dst, src, count); + } + + return orSuccess; + } + + orError_t getPointerAttributes( + orPointerAttributes* attributes, + const void* ptr) { + if (!attributes || !ptr) + return orErrorUnknown; + + std ::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + + attributes->type = info.type; + if (info.type == orMemoryType::orMemoryTypeUnmanaged) { + attributes->device = -1; + attributes->pointer = const_cast(ptr); + attributes->size = 0; + } else { + attributes->device = info.device; + attributes->pointer = info.pointer; + attributes->size = info.size; + } + + return orSuccess; + } + + orError_t unprotect(void* ptr) { + std::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + if (info.type != orMemoryType::orMemoryTypeDevice) { + return orErrorUnknown; + } + if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) { + return orErrorUnknown; + } + return orSuccess; + } + + orError_t protect(void* ptr) { + std::lock_guard lock(m_mutex); + orPointerAttributes info = getPointerInfo(ptr); + if (info.type != orMemoryType::orMemoryTypeDevice) { + return orErrorUnknown; + } + if (mprotect(info.pointer, info.size, PROT_NONE) != 0) { + return orErrorUnknown; + } + return orSuccess; + } + + private: + MemoryManager() = default; + orPointerAttributes getPointerInfo(const void* ptr) { + auto it = m_registry.upper_bound(const_cast(ptr)); + if (it == m_registry.begin()) + return {}; + --it; + const char* p_char = static_cast(ptr); + const char* base_char = static_cast(it->first); + if (p_char >= base_char && p_char < (base_char + it->second.size)) { + return it->second; + } + return {}; + } + std::map m_registry; + std::mutex m_mutex; +}; + +} // namespace internal +} // namespace openreg + +orError_t orMalloc(void** devPtr, size_t size) { + return openreg::internal::MemoryManager::getInstance().allocate( + devPtr, size, orMemoryType::orMemoryTypeDevice); +} + +orError_t orFree(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().free(devPtr); +} + +orError_t orMallocHost(void** hostPtr, size_t size) { + return openreg::internal::MemoryManager::getInstance().allocate( + hostPtr, size, orMemoryType::orMemoryTypeHost); +} + +orError_t orFreeHost(void* hostPtr) { + return openreg::internal::MemoryManager::getInstance().free(hostPtr); +} + +orError_t orMemcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + return openreg::internal::MemoryManager::getInstance().memcpy( + dst, src, count, kind); +} + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr) { + return openreg::internal::MemoryManager::getInstance().getPointerAttributes( + attributes, ptr); +} + +orError_t orMemoryUnprotect(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().unprotect(devPtr); +} + +orError_t orMemoryProtect(void* devPtr) { + return openreg::internal::MemoryManager::getInstance().protect(devPtr); +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h new file mode 100644 index 000000000000..b6b0b3da4295 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum orError_t { orSuccess = 0, orErrorUnknown = 1 } orError_t; + +typedef enum orMemcpyKind { + orMemcpyHostToHost = 0, + orMemcpyHostToDevice = 1, + orMemcpyDeviceToHost = 2, + orMemcpyDeviceToDevice = 3 +} orMemcpyKind; + +typedef enum orMemoryType { + orMemoryTypeUnmanaged = 0, + orMemoryTypeHost = 1, + orMemoryTypeDevice = 2 +} orMemoryType; + +struct orPointerAttributes { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device; + void* pointer; + size_t size; +}; + +orError_t orMalloc(void** devPtr, size_t size); +orError_t orFree(void* devPtr); +orError_t orMallocHost(void** hostPtr, size_t size); +orError_t orFreeHost(void* hostPtr); +orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); +orError_t orMemoryUnprotect(void* devPtr); +orError_t orMemoryProtect(void* devPtr); + +orError_t orGetDeviceCount(int* count); +orError_t orSetDevice(int device); +orError_t orGetDevice(int* device); + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py new file mode 100644 index 000000000000..32bb170075ef --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py @@ -0,0 +1,8 @@ +import torch +import torch_openreg._C # type: ignore[misc] +import torch_openreg.openreg + + +torch.utils.rename_privateuse1_backend("openreg") +torch._register_device_module("openreg", torch_openreg.openreg) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 000000000000..574b5b1c748a --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,12 @@ +set(LIBRARY_NAME torch_bindings) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg) +target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) + +install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp new file mode 100644 index 000000000000..4acdbfc8e1dc --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp @@ -0,0 +1,99 @@ +#include + +#include +#include +#include +#include +#include + +#include + +static PyObject* _initExtension(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "_get_default_generator expects an int, but got ", + THPUtils_typename(arg)); + auto idx = static_cast(THPUtils_unpackLong(arg)); + + return THPGenerator_initDefaultGenerator( + at::globalContext().defaultGenerator( + c10::Device(c10::DeviceType::PrivateUse1, idx))); + + END_HANDLE_TH_ERRORS +} + +PyObject* _setDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); + auto device = THPUtils_unpackLong(arg); + + torch::utils::device_lazy_init(at::kPrivateUse1); + c10::openreg::set_device(static_cast(device)); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _exchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::openreg::ExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDevice(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + auto device = static_cast(c10::openreg::current_device()); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(c10::openreg::device_count()); + END_HANDLE_TH_ERRORS +} + +static PyMethodDef methods[] = { + {"_init", _initExtension, METH_NOARGS, nullptr}, + {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, + {"_get_device", _getDevice, METH_NOARGS, nullptr}, + {"_set_device", _setDevice, METH_O, nullptr}, + {"_exchangeDevice", _exchangeDevice, METH_O, nullptr}, + {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +/* + * When ASAN is enabled, PyTorch modifies the dlopen flag during import, + * causing all global and weak symbols in _C.so and its dependent libraries + * to be exposed to the global symbol scope, which in turn causes + * subsequent symbols with the same name in other libraries to be intercepted. + * Therefore, it cannot be named initModule here, otherwise initModule + * in torch/csrc/Module.cpp will be called, resulting in failure. + */ +extern "C" PyObject* initOpenRegModule(void) { + static struct PyModuleDef openreg_C_module = { + PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; + PyObject* mod = PyModule_Create(&openreg_C_module); + + return mod; +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c new file mode 100644 index 000000000000..cd3eb4fe1ecc --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c @@ -0,0 +1,15 @@ +#include + +extern PyObject* initOpenRegModule(void); + +#ifndef _WIN32 +#ifdef __cplusplus +extern "C" +#endif +__attribute__((visibility("default"))) PyObject* PyInit__C(void); +#endif + +PyMODINIT_FUNC PyInit__C(void) +{ + return initOpenRegModule(); +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py new file mode 100644 index 000000000000..177468b8f41b --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py @@ -0,0 +1,72 @@ +import torch +import torch_openreg._C # type: ignore[misc] + + +_initialized = False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device): + self.idx = torch.accelerator._get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) + + def __exit__(self, type, value, traceback): + self.idx = torch_openreg._C._set_device(self.prev_idx) + return False + + +def is_available(): + return True + + +def device_count() -> int: + return torch_openreg._C._get_device_count() + + +def current_device(): + return torch_openreg._C._get_device() + + +def set_device(device) -> None: + return torch_openreg._C._set_device(device) + + +def is_initialized(): + return _initialized + + +def _lazy_init(): + global _initialized + if is_initialized(): + return + torch_openreg._C._init() + _initialized = True + + +from .random import * # noqa: F403 + + +__all__ = [ + "device", + "device_count", + "current_device", + "set_device", + "initial_seed", + "is_available", + "is_initialized", + "random", + "manual_seed", + "manual_seed_all", + "get_rng_state", + "set_rng_state", +] diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py new file mode 100644 index 000000000000..5202145a5552 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/random.py @@ -0,0 +1,60 @@ +import torch +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init, current_device, device_count + + +__all__ = [ + "get_rng_state", + "set_rng_state", + "manual_seed", + "manual_seed_all", + "initial_seed", +] + + +def get_rng_state(device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.get_state() + + +def set_rng_state(new_state, device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.set_state(new_state) + + +def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + +def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + +def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 5ad96979717c..3ab0b6269b2d 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -554,21 +554,6 @@ def _compare_params(self, m1, m2): p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local() self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}") - @with_comms - @skip_if_lt_x_gpu(4) - def test_raise_invalid_tp_composition(self): - with self.assertRaisesRegex( - RuntimeError, r"Found TP device_mesh on the \d dimension of its parent mesh" - ): - mesh_2d = init_device_mesh( - self.device_type, (2, self.world_size // 2), mesh_dim_names=("tp", "dp") - ) - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - } - parallelize_module(SimpleModel().cuda(), mesh_2d["tp"], parallelize_plan) - @with_comms @skip_if_lt_x_gpu(4) def test_2d_fsdp_state_enable_extension(self): diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py index 7eed02755610..788f78892fbb 100644 --- a/test/distributed/checkpoint/_experimental/test_builder.py +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -123,8 +123,8 @@ def test_make_async_checkpointer(self) -> None: # Create async checkpointer using factory function with default parameters config: CheckpointerConfig = CheckpointerConfig() config.staging_config = CheckpointStagerConfig( - use_cuda_non_blocking_copy=torch.cuda.is_available(), - use_pinned_memory=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), + use_pinned_memory=torch.accelerator.is_available(), ) checkpointer = make_async_checkpointer(config=config, rank_info=self.rank_info) diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index 0eeba5d63524..3fdb3bc022f2 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -74,7 +74,7 @@ def test_cuda_non_blocking_without_cuda(self) -> None: if torch.cuda.is_available(): self.skipTest("CUDA is available, cannot test CUDA unavailable scenario") - options = CheckpointStagerConfig(use_cuda_non_blocking_copy=True) + options = CheckpointStagerConfig(use_non_blocking_copy=True) with self.assertRaises(AssertionError): DefaultStager(options) @@ -86,21 +86,21 @@ def test_different_option_combinations(self) -> None: use_pinned_memory=False, use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), # Only pinned memory CheckpointStagerConfig( use_pinned_memory=True, use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), # Only shared memory CheckpointStagerConfig( use_pinned_memory=False, use_shared_memory=True, use_async_staging=False, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ), ] @@ -108,19 +108,19 @@ def test_different_option_combinations(self) -> None: # Only async staging test_cases.append( CheckpointStagerConfig( - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, use_async_staging=True, - use_cuda_non_blocking_copy=False, + use_non_blocking_copy=False, ) ) # Only CUDA non-blocking copy test_cases.append( CheckpointStagerConfig( - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, use_async_staging=False, - use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), ) ) @@ -129,7 +129,7 @@ def test_different_option_combinations(self) -> None: stager = DefaultStager(options) # Test staging works with these options - if options.use_async_staging and torch.cuda.is_available(): + if options.use_async_staging and torch.accelerator.is_available(): result = stager.stage(self.state_dict) self.assertIsInstance(result, Future) staged_dict = result.result() @@ -183,9 +183,9 @@ def test_multiple_staging_operations(self) -> None: """Test multiple staging operations with the same stager.""" options = CheckpointStagerConfig( use_async_staging=False, - use_pinned_memory=torch.cuda.is_available(), + use_pinned_memory=torch.accelerator.is_available(), use_shared_memory=False, - use_cuda_non_blocking_copy=torch.cuda.is_available(), + use_non_blocking_copy=torch.accelerator.is_available(), ) stager = DefaultStager(options) diff --git a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py index c2e37850d9d7..e1b1041875af 100644 --- a/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py +++ b/test/distributed/checkpoint/e2e/test_e2e_save_and_load.py @@ -279,7 +279,7 @@ def _run_e2e_test( use_async_staging=zoc, use_shared_memory=use_shared_memory, use_pinned_memory=zoc, - use_cuda_non_blocking_copy=zoc, + use_non_blocking_copy=zoc, ) stager = DefaultStager(staging_options) async_save_response_or_future = saver.async_save( diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py index af061e5b95c9..9d69d6d386a7 100644 --- a/test/distributed/checkpoint/test_fsspec.py +++ b/test/distributed/checkpoint/test_fsspec.py @@ -18,7 +18,10 @@ from torch.distributed.checkpoint.utils import CheckpointException from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -26,6 +29,10 @@ ) +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" +BACKEND = torch.distributed.get_default_backend_for_device(device_type) + + def with_temp_dir( func: Optional[Callable] = None, ) -> Optional[Callable]: @@ -75,14 +82,14 @@ class TestFSSpec(ShardedTensorTestBase): def world_size(self) -> int: return 2 - @with_comms(init_rpc=False) + @with_comms(backend=BACKEND, init_rpc=False) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) - @requires_nccl() @with_temp_dir def test_fsspec(self): CHECKPOINT_DIR = self.temp_dir - model = FSDP(MyTestModule().cuda()) + model = FSDP(MyTestModule().to(device_type)) optim = torch.optim.Adam(model.parameters(), lr=0.1) model(torch.rand(8, 8, device=dist.get_rank())).sum().backward() optim.step() @@ -99,7 +106,7 @@ def test_fsspec(self): planner=dcp.DefaultSavePlanner(), ) - model_2 = FSDP(MyTestModule().cuda()) + model_2 = FSDP(MyTestModule().to(device_type)) optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1) with FSDP.summon_full_params(model): @@ -149,9 +156,9 @@ def opt_at(opt, idx): opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"] ) - @with_comms(init_rpc=False) + @with_comms(backend=BACKEND, init_rpc=False) + @requires_accelerator_dist_backend() @skip_if_lt_x_gpu(2) - @requires_nccl() @with_temp_dir def test_overwrite(self): t1, t2 = torch.randn(10), torch.randn(10) diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index 0220ae5138fc..92f9b9723706 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -151,8 +151,6 @@ def test_consolidate_to_one_file(self) -> None: global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) checkpoint_dir = self.temp_dir - consolidated_output_dir = os.path.join(checkpoint_dir, "consolidated") - os.makedirs(consolidated_output_dir, exist_ok=True) state_dict_to_save = {"dtensor": dtensor} dist_cp.save( @@ -160,15 +158,13 @@ def test_consolidate_to_one_file(self) -> None: storage_writer=dist_cp.HuggingFaceStorageWriter( path=checkpoint_dir, save_distributed=True, - consolidated_output_path=consolidated_output_dir, + enable_consolidation=True, ), ) dist.barrier() if self.rank == 0: - file_path = os.path.join( - consolidated_output_dir, "model-00001-of-00001.safetensors" - ) + file_path = os.path.join(checkpoint_dir, "model-00001-of-00001.safetensors") loaded_dict = safetensors.torch.load_file(file_path) self.assertEqual(loaded_dict.keys(), {"dtensor"}) self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) diff --git a/test/distributed/checkpoint/test_pg_transport.py b/test/distributed/checkpoint/test_pg_transport.py index df64e9451b46..baa2eb54b054 100644 --- a/test/distributed/checkpoint/test_pg_transport.py +++ b/test/distributed/checkpoint/test_pg_transport.py @@ -1,13 +1,17 @@ # Owner(s): ["oncall: distributed"] import logging -import os from datetime import timedelta from typing import Optional from unittest.mock import MagicMock, patch import torch import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard as ShardedTensorShard, + ShardMetadata, +) from torch.distributed.checkpoint._pg_transport import ( _cast_tensor, _prepare_state_dict, @@ -34,9 +38,56 @@ logger = logging.getLogger(__name__) +def _create_sharded_tensor_state_dict( + rank: int, world_size: int, device: torch.device +) -> dict: + """ + Create state_dict with ShardedTensor for deterministic testing. + Args: + rank: Current rank + world_size: Total world size + device: Device to create tensors on + Returns: + dict: State dictionary with ShardedTensor + """ + # Create deterministic local shard for this rank + global_size = 64 + shard_size = global_size // world_size + start_idx = rank * shard_size + end_idx = (rank + 1) * shard_size + + # Create local tensor with deterministic values + local_tensor = torch.arange( + start_idx * 8, end_idx * 8, dtype=torch.float32, device=device + ).reshape(shard_size, 8) + + # Create ShardedTensor using init_from_local_shards + sharded_tensor = init_from_local_shards( + [ + ShardedTensorShard( + tensor=local_tensor, + metadata=ShardMetadata( + shard_offsets=[start_idx, 0], + shard_sizes=[shard_size, 8], + placement=f"rank:{rank}/{device}", + ), + ) + ], + global_size, + 8, + ) + + return { + "sharded_tensor": sharded_tensor, + "rank_scalar": torch.tensor(float(rank), device=device), + } + + class SimpleModel(nn.Module): - def __init__(self): + def __init__(self, seed: int = 42): super().__init__() + # Set seed for deterministic initialization + torch.manual_seed(seed) self.net1 = nn.Linear(10, 10) self.relu = nn.ReLU() self.net2 = nn.Linear(10, 10) @@ -50,6 +101,7 @@ def ring_send_recv_checkpoint( ): """ Use the transport to send to rank + 1 and receive from rank - 1. + Each rank exchanges its own state_dict with the previous rank. """ next_rank = (rank + 1) % world_size prev_rank = (rank - 1) % world_size @@ -58,15 +110,11 @@ def ring_send_recv_checkpoint( received_checkpoint = transport.recv_checkpoint(prev_rank) else: received_checkpoint = transport.recv_checkpoint(prev_rank) - transport.send_checkpoint([next_rank], received_checkpoint) + transport.send_checkpoint([next_rank], state_dict) return received_checkpoint def _test_pg_transport(self, device) -> None: - # python test/distributed/checkpoint/test_pg_transport.py -k test_pg_transport - print(f"{self.rank=} pid: {os.getpid()} {device=}") - print("in test") - model = SimpleModel().to(device) transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) original_state_dict = model.state_dict() @@ -111,6 +159,48 @@ def _test_pg_transport_with_mixed_content(self, device) -> None: self.assertEqual(state_dict, received_checkpoint) +def _test_pg_transport_with_sharded_tensor(self, device) -> None: + # Set current CUDA device for NCCL + if device.type == "cuda": + torch.cuda.set_device(device) + + state_dict = _create_sharded_tensor_state_dict(self.rank, self.world_size, device) + transport = PGTransport(_get_default_group(), timedelta(seconds=10), device) + print(state_dict) + received_checkpoint = ring_send_recv_checkpoint( + transport=transport, + state_dict=state_dict, + rank=self.rank, + world_size=self.world_size, + ) + print("finished comms") + print(received_checkpoint) + + # Validate that received checkpoint matches what we expect from rank - 1 + prev_rank = (self.rank - 1) % self.world_size + + # Compare rank_scalar (should be from previous rank) + # Note: PGTransport moves received tensors to CPU when no state_dict callback is provided + expected_rank_scalar = torch.tensor(float(prev_rank), device="cpu") + received_rank_scalar = received_checkpoint["rank_scalar"] # type: ignore[index] + print(f"{expected_rank_scalar=} {received_rank_scalar=}") + torch.testing.assert_close(expected_rank_scalar, received_rank_scalar) + + # For ShardedTensor, validate the local shard data matches what prev_rank would have + received_st = received_checkpoint["sharded_tensor"] # type: ignore[index] + global_size = 64 + shard_size = global_size // self.world_size + prev_start_idx = prev_rank * shard_size + prev_end_idx = (prev_rank + 1) * shard_size + expected_local_tensor = torch.arange( + prev_start_idx * 8, prev_end_idx * 8, dtype=torch.float32, device="cpu" + ).reshape(shard_size, 8) + + # Compare the actual tensor data + received_local_tensor = received_st.local_shards()[0].tensor + torch.testing.assert_close(expected_local_tensor, received_local_tensor) + + class PgTransportCPU(MultiProcContinousTest): world_size = 8 timeout: timedelta = timedelta(seconds=20) @@ -133,6 +223,9 @@ def test_pg_transport(self) -> None: def test_pg_transport_with_mixed_content(self) -> None: _test_pg_transport_with_mixed_content(self, self.device) + def test_pg_transport_with_sharded_tensor(self) -> None: + _test_pg_transport_with_sharded_tensor(self, self.device) + class PgTransportCUDA(MultiProcContinousTest): world_size = 2 @@ -160,6 +253,11 @@ def test_pg_transport(self) -> None: def test_pg_transport_with_mixed_content(self) -> None: _test_pg_transport_with_mixed_content(self, self.device) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_pg_transport_with_sharded_tensor(self) -> None: + _test_pg_transport_with_sharded_tensor(self, self.device) + class TestCastTensor(TestCase): def test_cast_tensor_different_dtypes(self): @@ -509,8 +607,5 @@ def test_send_checkpoint_with_cpu_tensors(self): self.assertGreaterEqual(self.mock_work.wait.call_count, 4) -# import fbvscode -# fbvscode.attach_debugger() - if __name__ == "__main__": run_tests() diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 37bb6def9a94..a42215e0ea0d 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -62,6 +62,9 @@ from torch.utils._pytree import tree_all, tree_all_only +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + + if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) @@ -79,7 +82,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): @property def world_size(self) -> int: - return min(4, torch.cuda.device_count()) + return min(4, torch.accelerator.device_count()) def _test_save_load( self, @@ -101,7 +104,7 @@ def _test_save_load( for d_optim in _dist_optim: d_optim.zero_grad() - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) model(batch).sum().backward() dist_model(batch).sum().backward() @@ -188,9 +191,9 @@ def _test_fsdp( def init_model_optim(): if use_dtensor: - device_mesh = init_device_mesh("cuda", (self.world_size,)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4, foreach=True) if wrapping: @@ -198,7 +201,7 @@ def init_model_optim(): else: strategy = {UnitModule} if use_dtensor: - device_mesh = init_device_mesh("cuda", (self.world_size,)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy(strategy), @@ -258,7 +261,7 @@ def _test_fsdp2( foreach: bool = True, ): def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class( orig_model.parameters(), lr=1e-4, foreach=foreach ) @@ -295,7 +298,7 @@ def test_fsdp2(self) -> None: def _test_ddp(self, use_composable: bool, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) if use_composable: @@ -329,7 +332,7 @@ def _test_fsdp_ddp( test_frozen: bool = False, ) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) if test_frozen: for param in chain( orig_model.u1.parameters(), orig_model.u2.parameters() @@ -370,7 +373,7 @@ def test_fsdp_ddp(self) -> None: def _test_single_gpu(self, optimizer_class: type[Optimizer]) -> None: def init_model_optim(): - orig_model = CompositeParamModel(device=torch.device("cuda")) + orig_model = CompositeParamModel(device=torch.device(device_type)) orig_optim = optimizer_class(orig_model.parameters(), lr=1e-4) copy_optim = optimizer_class(orig_model.parameters(), lr=1e-4) model_copy = copy.deepcopy(orig_model) @@ -385,7 +388,7 @@ def test_single_gpu(self) -> None: self._test_single_gpu(torch.optim.AdamW) def _test_strict(self, parallelism: str) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) if parallelism == "DDP": model = DDP(model) else: @@ -422,8 +425,8 @@ def test_strict(self) -> None: def _test_cpu_offload_full_state_dict( self, optimizer_class: type[Optimizer] ) -> None: - orig_model = CompositeParamModel(device=torch.device("cuda")) - device_mesh = init_device_mesh("cuda", (self.world_size,)) + orig_model = CompositeParamModel(device=torch.device(device_type)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) dist_model = FSDP( copy.deepcopy(orig_model), auto_wrap_policy=ModuleWrapPolicy({UnitModule}), @@ -499,7 +502,7 @@ def test_cpu_offload_full_state_dict(self) -> None: @skip_if_lt_x_gpu(1) def test_activation_ckpt_fqns_ddp(self) -> None: """Tests that activation checkpointing prefixes are removed from module names""" - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -518,7 +521,7 @@ def test_activation_ckpt_fqns_fsdp1(self) -> None: def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: """Tests that activation checkpointing prefixes are removed from module names""" - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) original_keys = get_model_state_dict(model).keys() apply_activation_checkpointing(model) @@ -529,7 +532,7 @@ def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None: @skip_if_lt_x_gpu(1) def test_extra_state(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) def get_extra_state(self): return "MyState" @@ -547,21 +550,21 @@ def set_extra_state(self, state): @skip_if_lt_x_gpu(1) def test_non_persistent_buffers(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) model.register_buffer( - "dont_save_me", torch.rand(100, device="cuda"), persistent=False + "dont_save_me", torch.rand(100, device=device_type), persistent=False ) target_model = copy.deepcopy(model) set_model_state_dict(target_model, get_model_state_dict(target_model)) self.assertEqual(model.state_dict(), get_model_state_dict(target_model)) def _test_broadcast_from_rank0(self, wrapper) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) optim = torch.optim.Adam(model.parameters()) fsdp_model = wrapper(copy.deepcopy(model)) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) model(batch).sum().backward() optim.step() states, optim_states = get_state_dict(model, optim) @@ -631,8 +634,8 @@ def check(equal): @with_comms @skip_if_lt_x_gpu(4) def test_broadcast_from_rank0(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - hsdp_device_mesh = init_device_mesh("cuda", (2, self.world_size // 2)) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + hsdp_device_mesh = init_device_mesh(device_type, (2, self.world_size // 2)) self.run_subtests( { "wrapper": [ @@ -654,8 +657,8 @@ def test_fsdp_root_not_initialized(self) -> None: # This test verifies that FSDP root is not initialized but we should # still be able to get the state_dict without errors because # fsdp_model.state_dict() will trigger the FSDP initialization. - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) fsdp_optim = torch.optim.Adam(fsdp_model.parameters()) get_model_state_dict(fsdp_model) @@ -668,10 +671,9 @@ def test_optim_state_dict_param_matching(self) -> None: # "initial_lr" is added to optim_state_dict, but not to the new optim # We test whether "initial_lr" appear in optim after # set_optimizer_state_dict. - device = "cuda" torch.manual_seed(0) model = nn.Sequential( - *[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)] + *[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)] ) for layer in model: fully_shard(layer) @@ -705,11 +707,11 @@ def test_optim_state_dict_param_matching(self) -> None: @with_comms @skip_if_lt_x_gpu(2) def test_flattened_osd(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = fully_shard(copy.deepcopy(model), mesh=device_mesh) fsdp_optim = torch.optim.AdamW(fsdp_model.parameters()) - batch = torch.rand(8, 100, device="cuda") + batch = torch.rand(8, 100, device=device_type) fsdp_model(batch).sum().backward() fsdp_optim.step() fsdp_optim.zero_grad() @@ -730,7 +732,7 @@ def test_flattened_osd(self) -> None: self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict()) def _test_deprecate_partial(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) model_state_dict1 = get_model_state_dict(model) model_state_dict1 = copy.deepcopy(model_state_dict1) @@ -783,8 +785,8 @@ def _test_deprecate_partial(self) -> None: self.assertEqual(model.l.bias, model_state_dict1["l.bias"]) def _test_deprecate_fsdp_api(self) -> None: - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = CompositeParamModel(device=torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = CompositeParamModel(device=torch.device(device_type)) fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh) with self.assertWarnsRegex( FutureWarning, @@ -823,8 +825,8 @@ def forward(self, input): return output def init_model_optim(): - device_mesh = init_device_mesh("cuda", (self.world_size,)) - orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda")) + device_mesh = init_device_mesh(device_type, (self.world_size,)) + orig_model = TiedEmbeddingModel(10000, 300).to(torch.device(device_type)) orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-4) dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh) @@ -905,8 +907,12 @@ def test_setting_meta_device_model_broadcasting_and_memory(self) -> None: self.assertEqual(cpu_model_value, meta_model_value) # Memory allocated and reserved are lower due to the change at _distribute_tensors # from view to clone. This test would fail if with view due to higher memory cost. - memory_allocated = torch.cuda.memory_allocated(0) / 1024 / 1024 - memory_reserved = torch.cuda.memory_reserved(0) / 1024 / 1024 + memory_allocated = ( + torch.get_device_module(device_type).memory_allocated(0) / 1024 / 1024 + ) + memory_reserved = ( + torch.get_device_module(device_type).memory_reserved(0) / 1024 / 1024 + ) self.assertTrue(memory_allocated <= 384) self.assertTrue(memory_reserved <= 768) @@ -942,11 +948,11 @@ def test_multi_device_load_model_state_dict(self) -> None: meta_submodel = nn.Linear(4, 4, bias=False) with torch.device("cpu"): cpu_submodel = nn.Linear(4, 4, bias=False) - with torch.device("cuda"): - cuda_submodel = nn.Linear(4, 4, bias=False) + with torch.device(device_type): + acc_submodel = nn.Linear(4, 4, bias=False) - two_device_model_with_meta = nn.Sequential(meta_submodel, cuda_submodel) - two_device_model_without_meta = nn.Sequential(cpu_submodel, cuda_submodel) + two_device_model_with_meta = nn.Sequential(meta_submodel, acc_submodel) + two_device_model_without_meta = nn.Sequential(cpu_submodel, acc_submodel) with torch.device("cpu"): model_to_set = nn.Sequential( @@ -974,7 +980,7 @@ def test_multi_device_load_model_state_dict(self) -> None: def test_state_dict_with_hook_on_keys(self) -> None: with torch.device("meta"): metamodel = FusionEmbedding(4, 4, 4) - with torch.device("cuda"): + with torch.device(device_type): gpumodel = FusionEmbeddingWithHook(4, 4, 4) gpumodel_state_dict = get_model_state_dict(gpumodel) with self.assertRaisesRegex(RuntimeError, "Missing key"): @@ -995,8 +1001,8 @@ def __init__(self): def forward(self, x): return self.fc1(self.fc(x)) - device_mesh = init_device_mesh("cuda", (self.world_size,)) - model = TestModel().cuda() + device_mesh = init_device_mesh(device_type, (self.world_size,)) + model = TestModel().to(device_type) parallelize_module( model, device_mesh, @@ -1014,7 +1020,7 @@ def _test_multi( optim = torch.optim.AdamW(**optim_kwargs) optim.zero_grad() - model(torch.randn(64, 64).cuda()).sum().backward() + model(torch.randn(64, 64, device=device_type)).sum().backward() optim.step() optim.zero_grad() @@ -1067,7 +1073,7 @@ def setUp(self) -> None: @skip_if_lt_x_gpu(1) def test_no_dist(self) -> None: - model = CompositeParamModel(device=torch.device("cuda")) + model = CompositeParamModel(device=torch.device(device_type)) optim = torch.optim.AdamW(model.parameters(), lr=1e-4) self.assertFalse(dist.is_initialized()) diff --git a/test/distributed/fsdp/test_fsdp_uneven.py b/test/distributed/fsdp/test_fsdp_uneven.py index f74f2ed94ebb..d0094ce1de71 100644 --- a/test/distributed/fsdp/test_fsdp_uneven.py +++ b/test/distributed/fsdp/test_fsdp_uneven.py @@ -45,7 +45,7 @@ def _get_ref_results(self, device, model, input, my_lr): def test_one_iteration(self, device): """Test FSDP with uneven divide of parameter shards.""" model = Linear(3, 3, bias=False) - input = torch.rand(8, 3) + input = torch.rand(self.world_size, 3) my_lr = 0.1 ref_forward_output_my_rank, ref_weight_out = self._get_ref_results( diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 50aa9ff21ba0..41967a0e5824 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -110,6 +110,137 @@ def test_forward_only(self, ScheduleClass): torch.testing.assert_close(x_clone, out) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize( + "ScheduleClass", + [ScheduleGPipe, Schedule1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS], + ) + def test_eval_inference_mode(self, ScheduleClass): + if ScheduleClass in [ScheduleInterleaved1F1B, ScheduleLoopedBFS]: + # Multi-stage schedules + stages_per_rank = 2 + n_stages = stages_per_rank * self.world_size + mod = MultiMLP(d_hid, n_layers=n_stages) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + stage_indices = [ + self.rank + i * self.world_size for i in range(stages_per_rank) + ] + submod_names = [f"layers.{i}" for i in stage_indices] + stage_modules = [ + mod.get_submodule(submod_name) for submod_name in submod_names + ] + stages = [ + PipelineStage( + stage_module, + stage_idx, + n_stages, + self.device, + ) + for stage_module, stage_idx in zip(stage_modules, stage_indices) + ] + + # Test with eval() method for inference + schedule = ScheduleClass(stages, chunks, loss_fn=loss_fn, scale_grads=False) + + # Clear gradients + for stage_module in stage_modules: + stage_module.zero_grad() + + if self.rank == 0: + schedule.eval(x) + elif self.rank == self.world_size - 1: + losses = [] + schedule.eval(target=target, losses=losses) + else: + schedule.eval() + + # Check that gradients were NOT computed during eval + grad_computed_eval = False + for stage_module in stage_modules: + for param in stage_module.parameters(): + if param.grad is not None: + grad_computed_eval = True + break + if grad_computed_eval: + break + + # Verify that gradients were not computed during eval + self.assertFalse( + grad_computed_eval, + "Gradients should not be computed during eval()", + ) + + # Verify that losses are still computed during eval + if self.rank == self.world_size - 1: + self.assertTrue( + len(losses) > 0, "Losses should be computed during eval()" + ) + else: + # Single-stage schedules + mod = MultiMLP(d_hid, n_layers=self.world_size) + mod.to(self.device) + + x = torch.randn(batch_size, d_hid, device=self.device) + target = torch.randn(batch_size, d_hid, device=self.device) + loss_fn = torch.nn.MSELoss(reduction="sum") + + chunks = 4 + x_mb = x.chunk(chunks)[0] + + # Create a pipeline + split_spec = mod.split_spec if hasattr(mod, "split_spec") else None + pipe = pipeline( + mod, + mb_args=(x_mb,), + split_spec=split_spec, + ) + + stage = pipe.build_stage( + self.rank, + self.device, + ) + + # Test with eval() method for inference + schedule = ScheduleClass(stage, chunks, loss_fn=loss_fn, scale_grads=False) + + # Get stage module for gradient checking + stage_module = pipe.get_stage_module(self.rank) + stage_module.zero_grad() + + if self.rank == 0: + schedule.eval(x) + elif self.rank == self.world_size - 1: + losses = [] + schedule.eval(target=target, losses=losses) + else: + schedule.eval() + + # Check that gradients were NOT computed during eval + grad_computed_eval = False + for param in stage_module.parameters(): + if param.grad is not None: + grad_computed_eval = True + break + + # Verify that gradients were not computed during eval + self.assertFalse( + grad_computed_eval, + "Gradients should not be computed during eval()", + ) + + # Verify that losses are still computed during eval + if self.rank == self.world_size - 1: + self.assertTrue( + len(losses) > 0, "Losses should be computed during eval()" + ) + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @@ -1048,6 +1179,5 @@ def test_zero_bubble_with_model_kwargs(self, ScheduleClass): instantiate_parametrized_tests(ScheduleTest) - if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py index 906b7d1a4a52..df3e2ffb3885 100644 --- a/test/distributed/tensor/parallel/test_micro_pipeline_tp.py +++ b/test/distributed/tensor/parallel/test_micro_pipeline_tp.py @@ -494,5 +494,55 @@ def test_dtensor_seq_par(self, shard_dim: int): self.assertNotIn("reduce_scatter_tensor", code) +@instantiate_parametrized_tests +class MicroPipelineTP4GPUTest(TestCase): + def setUp(self): + torch._inductor.config._micro_pipeline_tp = True + + self.rank = 0 + self.world_size = 4 + torch.cuda.set_device("cuda:0") + + store = FakeStore() + dist.init_process_group( + backend="fake", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + + def tearDown(self): + dist.destroy_process_group() + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @fresh_cache() + def test_extra_collectives(self): + device_mesh = DeviceMesh( + "cuda", + torch.arange(0, self.world_size).view(2, -1), + mesh_dim_names=("tp", "other"), + ) + + def func(inp: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor: + hidden = all_gather_tensor(inp, 0, (device_mesh, 0)) @ w1.t() + full_hidden = all_gather_tensor(hidden, 0, (device_mesh, 1)) + full_hidden /= full_hidden.pow(2).sum().sqrt() + hidden = reduce_scatter_tensor(full_hidden, "avg", 0, (device_mesh, 1)) + return reduce_scatter_tensor(hidden @ w2.t(), "avg", 0, (device_mesh, 0)) + + inp = torch.rand(8, 10, device="cuda") + w1 = torch.rand(7, 10, device="cuda") + w2 = torch.rand(10, 7, device="cuda") + + with _test_mode(group_names={device_mesh["tp"].get_group().group_name}): + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, inp, w1, w2) + + self.assertIn("fused_all_gather_matmul", code) + self.assertIn("all_gather_into_tensor", code) + self.assertIn("fused_matmul_reduce_scatter", code) + self.assertIn("reduce_scatter_tensor", code) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py index dd9f163ab4fa..a4efd6d5b6be 100644 --- a/test/distributed/tensor/test_api.py +++ b/test/distributed/tensor/test_api.py @@ -48,7 +48,7 @@ def world_size(self) -> int: def test_distribute_tensor_rank(self): comm_mode = CommDebugMode() - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] for requires_grad in [True, False]: @@ -134,7 +134,7 @@ def test_distribute_tensor_errors(self): @with_comms def test_distribute_tensor_uneven_sharding(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() input_sizes_and_shard_dims = [ ((self.world_size * 3 + 1, 3, 3), 0), ((self.world_size * 3 + 2, 3, 3), 0), @@ -156,7 +156,7 @@ def test_distribute_tensor_uneven_sharding(self): @with_comms def test_distribute_module(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all linear modules on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device=self.device_type) shard_spec = [Shard(0)] @@ -219,7 +219,7 @@ def shard_fn(name, module, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -264,7 +264,7 @@ def replicate_input_fn(mod, inputs, device_mesh): @with_comms def test_distribute_module_input_fn_output_fn_warning(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully replicate all linear modules module_to_replicate = MyModel(20, 1, device=self.device_type) @@ -292,7 +292,7 @@ def output_fn(outputs, device_mesh): @with_comms def test_distribute_module_casting(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # check DTensor casting dt = DTensor.from_local(torch.rand(10), device_mesh, [Replicate()]) @@ -335,7 +335,7 @@ def test_distribute_module_casting(self): def test_distribute_module_meta(self): # If the model is too big, the user may first the create entire model on the meta device and then initialize # it on the device in the partition function. - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() # fully shard all parameters on dim 0 module_to_shard = MyModel(5 * self.world_size, 20, device="meta") diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index b6588c2ad95e..d249a6d2ff77 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed import DeviceMesh from torch.distributed.tensor import ( distribute_module, distribute_tensor, @@ -48,7 +48,7 @@ def world_size(self) -> int: @with_comms def test_downsampling_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024) @@ -118,7 +118,7 @@ def test_downsampling_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_depthwise_convolution(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(3)] input_list = torch.rand(ITER_TIME, 7, 256, 128, 256) @@ -186,9 +186,7 @@ def test_depthwise_convolution(self): @with_comms @skip_if_lt_x_gpu(2) def test_conv_backward_none_grad_inp(self): - device_mesh = init_device_mesh( - device_type=self.device_type, mesh_shape=(self.world_size,) - ) + device_mesh = self.build_device_mesh() conv = nn.Conv2d(64, 64, 3, padding=1).train() x = torch.randn(1, 64, 32, 32) x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index b82661454bfc..73f4b709103f 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F from torch.distributed._functional_collectives import AsyncCollectiveTensor -from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, @@ -61,7 +60,7 @@ def reset_parameters(self, *args, **kwargs): class DTensorTest(DTensorTestBase): @with_comms def test_dtensor_constructor(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3, requires_grad=True) @@ -149,7 +148,7 @@ def test_modules_w_meta_dtensor(self): @with_comms def test_dtensor_stride(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] local_tensor = torch.randn(4, 8) dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec) @@ -172,7 +171,7 @@ def test_dtensor_stride(self): @with_comms def test_from_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -209,8 +208,7 @@ def test_from_local(self): @with_comms def test_from_local_uneven_sharding(self): - mesh_shape = (self.world_size,) - device_mesh = init_device_mesh(self.device_type, mesh_shape) + device_mesh = self.build_device_mesh() uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -235,8 +233,7 @@ def test_from_local_uneven_sharding(self): @with_comms def test_from_local_uneven_sharding_raise_error(self): - mesh_shape = (self.world_size,) - device_mesh = init_device_mesh(self.device_type, mesh_shape) + device_mesh = self.build_device_mesh() uneven_dim0_size = self.world_size + 1 global_tensor = torch.randn(uneven_dim0_size, 2) @@ -270,7 +267,7 @@ def test_from_local_uneven_sharding_raise_error(self): @with_comms def test_from_local_negative_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(-1)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -278,7 +275,7 @@ def test_from_local_negative_dim(self): @with_comms def test_to_local(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) local_tensor_with_grad = torch.randn( 3, 3, device=self.device_type, requires_grad=True @@ -338,7 +335,7 @@ def test_to_local(self): @with_comms def test_to_local_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -363,7 +360,7 @@ def test_to_local_grad_hint(self): @with_comms def test_full_tensor_sync(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -374,7 +371,7 @@ def test_full_tensor_sync(self): @with_comms def test_full_tensor_grad_hint(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = (Shard(0),) global_tensor = torch.ones(8, 3, requires_grad=True) @@ -387,7 +384,7 @@ def test_full_tensor_grad_hint(self): @with_comms def test_dtensor_new_empty_strided(self): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type) my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)]) new_strided_dtensor = my_dtensor.new_empty_strided( @@ -413,7 +410,7 @@ def test_dtensor_async_output(self): # Tests that if the output of some dtensor operations isn't used in any compute, # the output should be an AsyncCollectiveTensor (representing the fact that # we haven't synced the collective yet). - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() def fn(dt): dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True) @@ -453,7 +450,7 @@ def fn(dt): @with_comms def test_from_local_then_to_local(self): # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] # step 1. construct from construct local tensor @@ -485,7 +482,7 @@ def test_from_local_then_to_local(self): @with_comms def test_dtensor_spec_read_only_after_set(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -497,7 +494,7 @@ def test_dtensor_spec_read_only_after_set(self): @with_comms def test_dtensor_spec_hash(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) local_tensor2 = torch.randn(3, 3) @@ -517,7 +514,7 @@ def test_dtensor_spec_hash(self): @with_comms def test_dtensor_properties(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements) @@ -571,7 +568,7 @@ def test_dtensor_save_load_import(self): @with_comms def test_shard_tensor(self): ws = self.world_size - device_mesh = DeviceMesh(self.device_type, list(range(ws))) + device_mesh = self.build_device_mesh() full_tensor = torch.arange(ws * ws).reshape(ws, ws) # Shard by row @@ -622,7 +619,7 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): # construct a cuda device mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() # construct from a cpu local tensor with cuda device mesh # should automatically convert the dist tensor to cuda @@ -634,14 +631,14 @@ def test_dtensor_device_mesh_device_conversion(self): @with_comms def test_dtensor_api_device_mesh_context_manager(self): - with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: + with self.build_device_mesh() as mesh: placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local( local_tensor, device_mesh=mesh, placements=placements ) - with DeviceMesh(self.device_type, list(range(self.world_size))): + with self.build_device_mesh(): placements = [Shard(0)] local_tensor = torch.randn(3, 3) sharded_tensor = DTensor.from_local(local_tensor, placements=placements) @@ -651,7 +648,7 @@ def test_dtensor_api_device_mesh_context_manager(self): replica_tensor.size(), torch.Size([3 * self.world_size, 3]) ) - with DeviceMesh(self.device_type, torch.arange(self.world_size)): + with self.build_device_mesh(): placements = [Shard(0)] global_shape = torch.Size([3 * self.world_size, 3]) global_tensor = torch.randn(global_shape) @@ -837,7 +834,7 @@ def test_redistribute_sub_mesh(self): @with_comms def test_implicit_replication(self): - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() local_tensor1 = torch.ones(4, 3) sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)]) @@ -853,7 +850,7 @@ def test_implicit_replication(self): @with_comms def test_auto_implicit_replication(self): - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() local_tensor = torch.ones(self.world_size, 3, device=self.device_type) sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]) @@ -879,7 +876,7 @@ def add_scalar_tensor_with_dtensor(): @with_comms def test_metadata_consistency_check(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements = [Shard(0)] # Create a local tensor with specific metadata and check dtype change @@ -941,7 +938,7 @@ def _create_tensor(self, size): @with_comms def test_split_tensor_1D(self) -> None: - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() shard_placement = Shard(0) for size in range(8): diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index a26cf5da144f..86f1e9d8fb47 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -166,6 +166,8 @@ def forward(self, b_buffer, x): return (view_as_1,)""", # noqa: B950 ) + # During tracing, sharding propagation cache is skipped, so an extra dry run for + # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), """\ @@ -173,8 +175,8 @@ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None - add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None - view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None + add_1 = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None + view_1 = torch.ops.aten.view.default(add_1, [4, 4]); add_1 = None return (view_1,)""", # noqa: B950 ) @@ -276,6 +278,49 @@ def fn(x): res = opt_fn(x) self.assertEqual(res, ref) + @skipIfHpu + def test_dtensor_dynamic_slice(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in DTensor as inputs/outputs and run some tensor computation + def fn(x): + return [ + t.redistribute( + device_mesh=x.device_mesh, placements=[Replicate()] + ).to_local()[0] + for t in torch.tensor_split(x, 2) + ] + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True, dynamic=True) + res = opt_fn(x) + self.assertEqual(res, ref) + + @skipIfHpu + def test_dtensor_dynamic_cat(self): + # RESET COUNTS + + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # test passing in tuple of DTensors as + def fn(x, y): + return ( + torch.cat((x, y), dim=0) + .redistribute(device_mesh=x.device_mesh, placements=[Replicate()]) + .to_local()[0] + ) + + x = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + y = DTensor.from_local(torch.rand(4, 4), mesh, [Shard(0)], run_check=False) + torch._dynamo.mark_dynamic(x, 0) + ref = fn(x, y) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x, y) + self.assertEqual(res, ref) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index ba43335d1ddc..1e94cb7e359b 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -119,9 +119,6 @@ def wrapped(fn): xfail("cholesky_inverse"), xfail("cholesky_solve"), xfail("chunk"), - xfail("clamp"), - xfail("clamp_max"), - xfail("clamp_min"), xfail("combinations"), xfail("complex"), xfail("constant_pad_nd"), @@ -317,7 +314,6 @@ def wrapped(fn): xfail("nn.functional.multi_head_attention_forward"), xfail("nn.functional.multilabel_margin_loss"), xfail("nn.functional.multilabel_soft_margin_loss"), - xfail("nn.functional.normalize"), xfail("nn.functional.pad", "constant"), xfail("nn.functional.pad", "reflect"), xfail("nn.functional.pad", "replicate"), diff --git a/test/distributed/tensor/test_experimental_ops.py b/test/distributed/tensor/test_experimental_ops.py index d5d7f2406adb..ec4229a47b19 100644 --- a/test/distributed/tensor/test_experimental_ops.py +++ b/test/distributed/tensor/test_experimental_ops.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist -from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate +from torch.distributed.tensor import distribute_tensor, Replicate from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -24,7 +24,7 @@ def world_size(self) -> int: @with_comms def test_slice(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -76,7 +76,7 @@ def test_slice(self): @with_comms def test_bernoulli(self): rank = dist.get_rank() - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] input_list = torch.rand(ITER_TIME, 1024, 10) @@ -138,7 +138,7 @@ def test_bernoulli(self): @with_comms def test_nll(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Replicate()] pred_list = torch.rand(ITER_TIME, 1024, 10) diff --git a/test/distributed/tensor/test_init.py b/test/distributed/tensor/test_init.py index 540994954833..4212b6fc2c9b 100644 --- a/test/distributed/tensor/test_init.py +++ b/test/distributed/tensor/test_init.py @@ -37,7 +37,7 @@ def world_size(self): def _run_init_op(self, init_op, dist_init_op, eq_op, *args, **kwargs): # 1d mesh test - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() placements_list = [[Shard(0)], [Shard(1)], [Shard(2)], [Replicate()]] # even sharding @@ -132,7 +132,7 @@ def test_zeros(self): @with_comms def test_zeros_full_mesh(self): # construct a cuda device 1d mesh - mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh = self.build_device_mesh() placements = [Shard(0)] size = [32, 3] dist_tensor = zeros(size, device_mesh=mesh, placements=placements) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 79a7112e0f19..93ce80f18ee1 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -572,6 +572,104 @@ def forward(self, tokens): f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" ) + @with_comms + def test_rms_norm_bwd(self): + device_mesh = self.build_device_mesh() + + # NLP example from pytorch docs + batch, sentence_length, embedding_dim = 20, 5, 10 + norm_shape_idx_list = list(range(3)) + shard_dims = [0] # non-first dimensional sharding is not supported + elementwise_affine_list = [False, True] + test_config_list = list( + itertools.product(shard_dims, norm_shape_idx_list, elementwise_affine_list) + ) + + # normalized shape is a torch.Size object + for shard_dim, norm_idx, elementwise_affine in test_config_list: + x = torch.rand( + batch, + sentence_length, + embedding_dim, + device=self.device_type, + requires_grad=True, + ) + normalized_shape = x.shape[norm_idx:] + rms_norm = torch.nn.RMSNorm( + normalized_shape, + elementwise_affine=elementwise_affine, + device=self.device_type, + ) + rms_norm_local = copy.deepcopy(rms_norm).to(self.device_type) + + def _replicate_fn(name, module, device_mesh): + for name, param in module.named_parameters(): + if name == "weight": + param_dist = torch.nn.Parameter( + distribute_tensor(param, device_mesh, [Replicate()]) + ) + module.register_parameter(name, param_dist) + + rms_norm_dist = distribute_module(rms_norm, device_mesh, _replicate_fn) + + if elementwise_affine: + self.assertEqual( + rms_norm_local.weight, rms_norm_dist.weight.full_tensor() + ) + + x_local = x.detach().clone().requires_grad_(True) + x_dist = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + self.assertEqual(x_local, x_dist.full_tensor()) + + y_local = rms_norm_local(x_local) + # make sure that backward rms norm does not introduce extra collectives + comm_mode = CommDebugMode() + with comm_mode: + y_dist = rms_norm_dist(x_dist) + y_dist.sum().backward() + + # TODO: forward pass is sharding strategy is generated from composite, hence 1 more collective than layer_norm + # see: https://github.com/pytorch/pytorch/pull/158716#issuecomment-3096012679 + expected_fwd_comm = 0 if shard_dim < norm_idx else 2 + + self.assertEqual( + sum(comm_mode.comm_module_counts["Global"]["forward"].values()), + expected_fwd_comm, + f"comm count={comm_mode.get_total_counts()}, " + f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", + ) + + self.assertEqual(y_local, y_dist.full_tensor()) + + # backward step + y_local.sum().backward() + + expected_bwd_comm = 0 if shard_dim < norm_idx else 1 + + self.assertEqual( + sum(comm_mode.comm_module_counts["Global"]["backward"].values()), + expected_bwd_comm, + f"comm count={comm_mode.get_total_counts()}, " + f"shard_dim={shard_dim}, norm_shape={normalized_shape}, elem_affine={elementwise_affine}", + ) + + if elementwise_affine: + # if input is sharded on any outer dimension, the gradient of weight + # should be Partial + dim_map = x_dist._spec.dim_map + outer_dims = range(norm_idx) + needs_reduction = any(dim_map[d] >= 0 for d in outer_dims) + self.assertEqual( + is_tensor_partial(rms_norm_dist.weight.grad._spec), + needs_reduction, + ) + self.assertEqual( + rms_norm_local.weight.grad, + rms_norm_dist.weight.grad.full_tensor(), + ) + + self.assertEqual(x_local.grad, x_dist.grad.full_tensor()) + @with_comms def test_topk(self): device_mesh = self.build_device_mesh() @@ -811,6 +909,42 @@ def apply_rotary_emb(xq, freqs_cis): self.assertEqual(dtensor_grad, xq.grad) + @with_comms + def test_histc(self): + # TODO - nicer to use parametrize here so its easy to run one sub-test by name, + # but its too slow (10sec per process-group init) -> switch to MultiProcessContinuousTest + device_mesh = self.build_device_mesh() + comm_mode = CommDebugMode() + tensor = torch.randn(12, 8, 8, requires_grad=True) + for min_max_specified in (True, False): + for placement in [Shard(0), Shard(1), Shard(2), Replicate()]: + min_ = tensor.min().item() + max_ = tensor.max().item() + global_bins = ( + tensor.histc(min=min_, max=max_) + if min_max_specified + else tensor.histc() + ) + + dtensor = distribute_tensor(tensor, device_mesh, (placement,)) + with comm_mode: + out_dt = ( + dtensor.histc(min=min_, max=max_) + if min_max_specified + else dtensor.histc() + ) + + if placement.is_shard() and not min_max_specified: + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1 + ) + else: + self.assertEqual(comm_mode.get_total_counts(), 0) + + out_full = out_dt.full_tensor() + self.assertEqual(global_bins, out_full) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index d0f8482c0cf5..e9baf2102b25 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -7,7 +7,7 @@ import torch import torch.nn.functional as F -from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed import init_device_mesh from torch.distributed.tensor import ( distribute_tensor, DTensor, @@ -19,7 +19,12 @@ from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type -from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_WITH_ROCM, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_unless_torch_gpu, @@ -47,7 +52,7 @@ def scale_for_fp8( class DistMatrixOpsTest(DTensorTestBase): @with_comms def test_addmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -64,7 +69,7 @@ def test_addmm(self): @with_comms def test_addmm_empty_operand(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] replica_spec = [Replicate()] @@ -81,7 +86,7 @@ def test_addmm_empty_operand(self): @with_comms def test_addmm_auto_redistribute(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = [Shard(0)] shard1_spec = [Shard(1)] replica_spec = [Replicate()] @@ -112,7 +117,7 @@ def test_addmm_auto_redistribute(self): @with_comms def test_mm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard0_spec = Shard(0) shard1_spec = Shard(1) replica_spec = Replicate() @@ -147,7 +152,7 @@ def test_placement_comb( "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", ) def test_scaled_mm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shrd0 = Shard(0) shrd1 = Shard(1) repl = Replicate() @@ -217,7 +222,7 @@ def test_scaled_mm(self): @with_comms def test_matmul(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() dim = 128 x = torch.randn(8, dim) A = torch.randn(dim, dim) @@ -236,7 +241,7 @@ def test_matmul(self): @with_comms def test_t(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_transpose = torch.randn(12, 8, requires_grad=True) @@ -250,7 +255,7 @@ def test_t(self): @with_comms def test_t_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() a = torch.randn(12, 8) b = torch.randn(8, 4) @@ -275,7 +280,7 @@ def test_t_partial(self): @with_comms @skip_unless_torch_gpu def test_baddbmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() tensor = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_1 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) batch_2 = torch.rand(4, 8, 8, device=self.device_type, requires_grad=True) @@ -339,7 +344,7 @@ def test_placement_comb( @with_comms def test_bmm(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) local_result = torch.bmm(mat1, mat2) @@ -384,7 +389,7 @@ def test_placement_comb( @with_comms @skip_unless_torch_gpu def test_scaled_dot_product_attention(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # bsz, n_heads, slen, head_dim query = torch.rand( @@ -487,7 +492,7 @@ def test_tensordot_shampoo(self): """ Create a simple test for Shampoo's use case. """ - device_mesh = init_device_mesh(self.device_type, (self.world_size,)) + device_mesh = self.build_device_mesh() local_a = torch.randn(4, 4) local_b = torch.randn(4, 15) @@ -508,40 +513,78 @@ def test_tensordot_shampoo(self): @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @with_comms @skip_unless_torch_gpu - def test_grouped_mm(self): + @parametrize( + "kwargs", + [ + { + # 2D x 3D case from MoE layer + "inp_shape": (64, 16), + "w1_shape": (2, 16, 32), + "w2_shape": (2, 32, 16), + "inp_placements": [Replicate()], + "w1_placements": [Shard(2)], + "w2_placements": [Shard(1)], + "expected_comm_counts_fwd": 0, + "expected_comm_counts_bwd": 1, + "expected_out_placements": [Partial()], + }, + { + # Case that would have invalid strides on inp * mat1 when sharded + "inp_shape": (64, 16), + "w1_shape": (2, 16, 16), + "w2_shape": (2, 16, 16), + "inp_placements": [Replicate()], + "w1_placements": [Shard(2)], + "w2_placements": [Shard(1)], + "expected_comm_counts_fwd": 2, + "expected_comm_counts_bwd": 4, + "expected_out_placements": [Replicate()], + }, + ], + ) + def test_grouped_mm(self, kwargs): # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) - # Here we only test the 2D x 3D Tensor Parallel use case in an MoE layer. # More tests need to be added. - device_mesh = init_device_mesh(self.device_type, (self.world_size,)) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() dtype = torch.bfloat16 - inp = torch.rand( - 64, 16, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["inp_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) w1 = torch.rand( - 2, 16, 32, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["w1_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) w2 = torch.rand( - 2, 32, 16, device=self.device_type, dtype=dtype, requires_grad=True + *kwargs["w2_shape"], + device=self.device_type, + dtype=dtype, + requires_grad=True, ) offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32) h = torch._grouped_mm(inp, w1, offs=offs) out = torch._grouped_mm(h, w2, offs=offs) - dist_inp = distribute_tensor(inp, device_mesh, [Replicate()]) + dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"]) # colwise sharded - dist_w1 = distribute_tensor(w1, device_mesh, [Shard(2)]) + dist_w1 = distribute_tensor(w1, device_mesh, kwargs["w1_placements"]) # rowwise sharded - dist_w2 = distribute_tensor(w2, device_mesh, [Shard(1)]) + dist_w2 = distribute_tensor(w2, device_mesh, kwargs["w2_placements"]) dist_offs = distribute_tensor(offs, device_mesh, [Replicate()]) with comm_mode: dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs) dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs) - self.assertEqual(comm_mode.get_total_counts(), 0) - self.assertTrue(dist_out.placements[0].is_partial()) + self.assertEqual( + comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"] + ) + self.assertEqual(dist_out.placements, kwargs["expected_out_placements"]) self.assertEqual(dist_out.full_tensor(), out) out_grad = torch.ones_like(out) @@ -552,15 +595,19 @@ def test_grouped_mm(self): with comm_mode: dist_out.backward(dist_out_grad) - self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual( + comm_mode.get_total_counts(), kwargs["expected_comm_counts_bwd"] + ) self.assertEqual( comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], - 1, + kwargs["expected_comm_counts_bwd"], ) self.assertEqual(dist_inp.grad.full_tensor(), inp.grad) self.assertEqual(dist_w1.grad.full_tensor(), w1.grad) self.assertEqual(dist_w2.grad.full_tensor(), w2.grad) +instantiate_parametrized_tests(DistMatrixOpsTest) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_optimizers.py b/test/distributed/tensor/test_optimizers.py index 7e69f362183d..c876f28e165b 100644 --- a/test/distributed/tensor/test_optimizers.py +++ b/test/distributed/tensor/test_optimizers.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn from torch.distributed.tensor import ( - DeviceMesh, distribute_module, distribute_tensor, DTensor, @@ -89,7 +88,7 @@ def test_optimizer_foreach_supported_types_include_DTensor(self): @with_comms def test_adam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # lr as a Tensor is not supported for capturable=False and foreach=True adam_float_lr_configs = [ @@ -148,7 +147,7 @@ def test_adam_1d_sharding(self): @with_comms def test_adamw_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() # lr as a Tensor is not supported for capturable=False and foreach=True adamw_float_lr_configs = [ @@ -224,7 +223,7 @@ def test_adamw_1d_sharding(self): @with_comms def test_sgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() sgd_configs = [ {"lr": 0.1, "foreach": False}, @@ -264,7 +263,7 @@ def test_sgd_1d_sharding(self): @with_comms def test_adagrad_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adagrad_configs = [ {"lr": 0.1, "foreach": False}, @@ -320,7 +319,7 @@ def test_adagrad_1d_sharding(self): @with_comms def test_RMSprop_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() RMSprop_configs = [ {"lr": 0.1, "foreach": False}, @@ -387,7 +386,7 @@ def test_RMSprop_1d_sharding(self): @with_comms def test_adadelta_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adadelta_configs = [ {"lr": 0.1, "foreach": False}, @@ -431,7 +430,7 @@ def test_adadelta_1d_sharding(self): @with_comms def test_nadam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() nadam_configs = [ {"lr": 0.1, "foreach": False}, @@ -468,7 +467,7 @@ def test_nadam_1d_sharding(self): @with_comms def test_radam_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() radam_configs = [ {"lr": 0.1, "foreach": False}, @@ -508,7 +507,7 @@ def test_radam_1d_sharding(self): @with_comms def test_adamax_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() adamax_configs = [ {"lr": 0.1, "foreach": False}, @@ -552,7 +551,7 @@ def test_adamax_1d_sharding(self): @with_comms def test_asgd_1d_sharding(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() asgd_configs = [ {"lr": 0.1, "foreach": False}, diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index b56f32dbcaea..fe07b0dd6a24 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -40,7 +40,7 @@ def world_size(self): @parametrize("dtype", [torch.float32, torch.cfloat]) def test_shard_to_replicate_forward_backward(self, dtype): # 1) test shard -> replicate forward - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -82,7 +82,7 @@ def test_shard_to_replicate_forward_backward(self, dtype): @with_comms def test_replicate_to_replicate_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) @@ -111,7 +111,7 @@ def test_replicate_to_replicate_forward_backward(self): @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_replicate_to_local_partial_grad(self, dtype): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] local_tensor = torch.randn( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype @@ -132,7 +132,7 @@ def test_replicate_to_local_partial_grad(self, dtype): @with_comms def test_replicate_to_shard_forward_backward(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] input_sizes_and_shard_dim = [ @@ -185,7 +185,7 @@ def test_partial_to_replicate_forward_backward(self, dtype): # placement (i.e. user can't reshard to partial), we do allow # replicate to partial internally, and also partial to replicate # backward should work as expected - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_local = torch.ones( 12, 3, device=self.device_type, requires_grad=True, dtype=dtype ) @@ -220,7 +220,7 @@ def test_partial_to_replicate_forward_backward(self, dtype): @with_comms def test_replicate_to_replicate_forward_backward_datatype_conversion(self): - device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] forward_datatypes = [ @@ -277,7 +277,7 @@ def test_replicate_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_shard_to_replicate_forward_backward_datatype_conversion(self): - device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + device_mesh = self.build_device_mesh() replica_spec = [Replicate()] shard_dim_and_input_sizes = [ @@ -349,7 +349,7 @@ def test_shard_to_replicate_forward_backward_datatype_conversion(self): @with_comms def test_replicate_to_partial(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) partial_spec = Partial() replica_spec = Replicate() @@ -398,7 +398,7 @@ def test_replicate_to_partial(self): @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_partial_to_shard(self, dtype): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() partial_spec = [Partial()] my_rank = device_mesh.get_rank() @@ -453,7 +453,7 @@ def test_partial_to_shard(self, dtype): @with_comms def test_redistribute_negative_shard_dim(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True) shard_spec = [Shard(1)] shard_minus_spec = [Shard(-1)] @@ -491,7 +491,7 @@ def test_redistribute_uneven_sharding(self): @parametrize("dtype", [torch.float32, torch.cfloat]) def test_redistribute_shard_dim_change(self, dtype): # test 1d device mesh - mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size)) + mesh_1d = self.build_device_mesh() data_to_test = [ # evenly sharded case torch.randn((8, 8), device=self.device_type, dtype=dtype), diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 3c0e65809c7c..0e75748be8a3 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -2,7 +2,6 @@ # Owner(s): ["oncall: distributed"] import torch -import torch.distributed._functional_collectives as funcol from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, @@ -26,7 +25,7 @@ class DistTensorOpsTest(DTensorTestBase): @with_comms def test_aten_contiguous(self): # this op not covered by dtensor_ops - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() self._test_op( mesh, lambda x: torch.ops.aten.contiguous(x), @@ -35,7 +34,7 @@ def test_aten_contiguous(self): @with_comms def test_detach(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] tensor_to_detach = torch.randn(12, 8, requires_grad=True) @@ -45,7 +44,7 @@ def test_detach(self): @with_comms def test_clone(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() specs = [[Replicate()], [Shard(0)]] tensor_to_clone = torch.randn(12, 8, requires_grad=True) for spec in specs: @@ -55,8 +54,48 @@ def test_clone(self): self.assertEqual(cloned_mat.to_local(), mat.to_local()) @with_comms - def test_contiguous(self): + def test_copy_(self): device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + # basic test + src_tensor = torch.randn((12, 12)) + dst_tensor = torch.zeros(12, 12) + src_specs = [[Replicate()], [Shard(0)]] + dst_specs = [[Replicate()], [Shard(0)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, dst_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, src_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # simple broadcasting + src_tensor = torch.randn((128,)) + dst_tensor = torch.zeros(128, 128) + src_specs = [[Replicate()], [Shard(0)]] + dst_specs = [[Replicate()], [Shard(1)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + # The src specs in this case are designed to not be compatible with the dst_specs, redistribute should happen + src_tensor = torch.randn((64, 1)) + dst_tensor = torch.zeros(16, 32, 64, 128) + src_specs = [[Shard(1)], [Shard(1)], [Shard(1)], [Shard(1)]] + dst_specs = [[Replicate()], [Shard(0)], [Shard(1)], [Shard(2)]] + for dst_spec, src_spec in zip(dst_specs, src_specs): + src_dtensor = distribute_tensor(src_tensor, device_mesh, src_spec) + dst_dtensor = distribute_tensor(dst_tensor, device_mesh, dst_spec) + dst_dtensor.copy_(src_dtensor) + dst_tensor.copy_(src_tensor) + self.assertEqual(dst_dtensor.full_tensor(), dst_tensor) + + @with_comms + def test_contiguous(self): + device_mesh = self.build_device_mesh() tensor = torch.rand(3, 5, 6, requires_grad=True) sharding = [Shard(0)] dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) @@ -82,7 +121,7 @@ def test_contiguous(self): @with_comms def test_inplace_op(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) dt_to_mul = dt_to_add.clone() @@ -109,7 +148,7 @@ def test_inplace_op(self): @with_comms def test_op_out_variant(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mesh = self.build_device_mesh() input_tensor = torch.randn((12, 3), device=self.device_type) sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) expected_dt = sharded_dt_input.clone() + 3 @@ -130,7 +169,7 @@ def test_op_out_variant(self): @with_comms def test_empty_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -141,7 +180,7 @@ def test_empty_like(self): @with_comms def test_fill_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -153,7 +192,7 @@ def test_fill_inplace(self): @with_comms def test_full_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -164,7 +203,7 @@ def test_full_like(self): @with_comms def test_ones_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -175,7 +214,7 @@ def test_ones_like(self): @with_comms def test_ones_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -188,7 +227,7 @@ def test_ones_like_partial_sum(self): @with_comms def test_fill_inplace_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -204,7 +243,7 @@ def test_fill_inplace_partial_sum(self): @with_comms def test_zeros_like_partial_sum(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Partial()] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -217,7 +256,7 @@ def test_zeros_like_partial_sum(self): @with_comms def test_zero_inplace(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -229,7 +268,7 @@ def test_zero_inplace(self): @with_comms def test_zeros_like(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor = torch.randn(4, 8, requires_grad=True) @@ -281,7 +320,7 @@ def test_stack(self): @with_comms def test_equal(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] input_tensor_1 = torch.ones(4, 4) @@ -331,7 +370,7 @@ def _test_op(self, mesh, op_call, *args, **kwargs): @with_comms def test_new_full(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() global_tensor = torch.randn(12, 8) @@ -358,7 +397,7 @@ def test_new_full(self): @with_comms def test_new_empty_strided(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() shard_dim = 1 @@ -403,7 +442,7 @@ def test_new_empty_strided(self): @with_comms def test_scatter(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index/src replicated, output replicated @@ -437,7 +476,7 @@ def test_scatter(self): @with_comms def test_gather(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() # case 1 all replicate: input replicated, index replicated, output replicated @@ -488,7 +527,7 @@ def test_gather(self): @with_comms def test_index(self): meshes = [ - DeviceMesh(self.device_type, list(range(self.world_size))), # 1D mesh + self.build_device_mesh(), # 1D mesh # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh ] @@ -638,7 +677,7 @@ def test_index_put_tensor(self): @with_comms def test_where_type_promotion(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh + mesh = self.build_device_mesh() # 1D mesh specs = [[Shard(0)], [Replicate()]] for spec in specs: @@ -650,7 +689,7 @@ def test_where_type_promotion(self): @with_comms def test_dtensor_dtype_conversion(self): - device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + device_mesh = self.build_device_mesh() shard_spec = [Shard(0)] # by default we start from bf16 dtype local_tenor = torch.randn(2, 8, dtype=torch.bfloat16) @@ -684,7 +723,7 @@ def test_dtensor_dtype_conversion(self): @with_comms def test_slice(self): - mesh = DeviceMesh(self.device_type, list(range(self.world_size))) # 1D mesh + mesh = self.build_device_mesh() # 1D mesh comm_mode = CommDebugMode() shard_spec = [Shard(1)] @@ -722,27 +761,20 @@ def test_split_on_partial(self): def _test_split_on_partial(self, reduce_op: str, split_size: int, split_dim: int): torch.manual_seed(self.rank) - mesh = init_device_mesh(self.device_type, (self.world_size,)) + mesh = self.build_device_mesh() partial_tensor = torch.randn(8, 8, device=self.device_type) - replicate_tensor = partial_tensor.detach().clone() - replicate_tensor = funcol.all_reduce( - replicate_tensor, reduce_op, mesh - ) # all reduce to full tensor - replicate_tensor_list = replicate_tensor.split(split_size, dim=split_dim) - partial_dt = DTensor.from_local( local_tensor=partial_tensor, device_mesh=mesh, placements=[Partial(reduce_op=reduce_op)], ) - partial_dt_list = partial_dt.split(split_size, dim=split_dim) - - replicate_dt_full_tensor_list = [dt.full_tensor() for dt in partial_dt_list] - for replicate_tensor, replicate_dt_full_tensor in zip( - replicate_tensor_list, replicate_dt_full_tensor_list - ): - self.assertEqual(replicate_tensor, replicate_dt_full_tensor) + self._test_op_on_dtensor( + torch.split, + partial_dt, + split_size, + dim=split_dim, + ) if __name__ == "__main__": diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 8629ca5261cf..dbfbac12223b 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -179,7 +179,7 @@ def test_compute_global_tensor_shape_1D_invalid_shape(self): ) with self.assertRaisesRegex( RuntimeError, - "Non-sharded dimentions should have identical size across ranks.", + "Non-sharded dimensions should have identical size across ranks.", ): _ = compute_global_tensor_shape( local_shape, diff --git a/test/distributed/test_c10d_pypg.py b/test/distributed/test_c10d_pypg.py index 6ccb81b116ca..65faf2075daa 100644 --- a/test/distributed/test_c10d_pypg.py +++ b/test/distributed/test_c10d_pypg.py @@ -181,6 +181,19 @@ def use_wrapper(self): return True +class BlockWork(dist._Work): + """ + Dummy work that is used to test blocking the current stream. + """ + + def __init__(self): + super().__init__() + self.future_ = torch.futures.Future() + + def get_future(self): + return self.future_ + + class TestPyProcessGroup(TestCase): def test_attr_overrides(self): pg = DummyAttrProcessGroup(0, 1) @@ -202,34 +215,61 @@ def test_abort_shutdown(self) -> None: @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") def test_block_current_stream(self) -> None: - class BlockWork(dist._Work): - def __init__(self): - super().__init__() - self.future_ = torch.futures.Future() - - def get_future(self): - return self.future_ - - # nothing in queue so instantly resolves - event1 = torch.cuda.Event() - event1.record() - time.sleep(0.1) - self.assertTrue(event1.query()) - - work = BlockWork() - work.block_current_stream() - - # stream is blocked so doesn't resolve - event = torch.cuda.Event() - event.record() - time.sleep(0.1) - self.assertFalse(event.query()) - - # resolve the work - work.get_future().set_result(None) - - torch.cuda.current_stream().synchronize() - self.assertTrue(event.query()) + torch.cuda.synchronize() + + stream = torch.cuda.Stream() + with stream: + # nothing in queue so instantly resolves + event1 = torch.cuda.Event() + event1.record() + time.sleep(0.1) + self.assertTrue(event1.query()) + + work = BlockWork() + work.block_current_stream() + + # stream is blocked so doesn't resolve + event = torch.cuda.Event() + event.record() + time.sleep(0.1) + self.assertFalse(event.query()) + + # resolve the work + work.get_future().set_result(None) + + stream.synchronize() + self.assertTrue(event.query()) + + @unittest.skipIf(not TEST_CUDA, "no cuda/xpu") + def test_block_current_stream_use_after_free(self) -> None: + """ + This tests that the CPU control tensor is not freed before the CUDA kernel executes. + """ + torch.cuda.synchronize() + stream = torch.cuda.Stream() + with stream: + a = BlockWork() + a.block_current_stream() + + b = BlockWork() + b.block_current_stream() + + # unblock b first though a is still blocking + b.get_future().set_result(None) + # delete b + del b + + # a is still blocking so this doesn't resolve + event = torch.cuda.Event() + event.record() + time.sleep(0.1) + self.assertFalse(event.query()) + + # unblock a + a.get_future().set_result(None) + + stream.synchronize() + self.assertTrue(event.query()) if __name__ == "__main__": diff --git a/test/distributed/test_dist2.py b/test/distributed/test_dist2.py index 52ffd34e2a48..baaaf0550acd 100644 --- a/test/distributed/test_dist2.py +++ b/test/distributed/test_dist2.py @@ -5,6 +5,7 @@ from datetime import timedelta import torch +import torch.distributed as dist import torch.distributed._dist2 as dist2 from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -28,10 +29,14 @@ def test_context_manager(self): os.environ["MASTER_PORT"] = "29500" pg1 = dist2.new_group( - backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None + backend="gloo", + timeout=timedelta(seconds=60), + device="cpu", ) pg2 = dist2.new_group( - backend="gloo", timeout=timedelta(seconds=60), device="cpu", pg_options=None + backend="gloo", + timeout=timedelta(seconds=60), + device="cpu", ) self.assertIsNone(dist2.current_process_group()) @@ -201,6 +206,50 @@ def test_alltoall_base(self) -> None: out_range = out[i * 10 : (i + 1) * 10] self.assertEqual(out_range, torch.full_like(out_range, i + 1)) + def test_group_split(self) -> None: + group = self.new_group() + subgroup = group.split_group([0], timeout=timedelta(seconds=30)) + if self.rank == 0: + assert subgroup is not None + self.assertEqual(subgroup.size(), 1) + backend = subgroup._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=30)) + else: + self.assertEqual(subgroup, None) + + def test_remote_group_merge(self) -> None: + group = self.new_group() + subgroup_1 = group.split_group([0], timeout=timedelta(seconds=30)) + subgroup_2 = group.split_group([1], timeout=timedelta(seconds=30)) + if self.rank == 0: + assert subgroup_1 is not None + tcp_store = dist.TCPStore( + host_name=os.environ["MASTER_ADDR"], + port=29781, + world_size=2, + is_master=True, + ) + merged_pg = subgroup_1.merge_remote_group( + tcp_store, 2, timedelta(seconds=40), "merged_pg" + ) + self.assertEqual(merged_pg.size(), 2) + backend = merged_pg._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=40)) + else: + assert subgroup_2 is not None + tcp_store = dist.TCPStore( + host_name=os.environ["MASTER_ADDR"], + port=29781, + world_size=2, + is_master=False, + ) + merged_pg = subgroup_2.merge_remote_group( + tcp_store, 2, timedelta(seconds=40), "merged_pg" + ) + self.assertEqual(merged_pg.size(), 2) + backend = merged_pg._get_backend(self.device) + self.assertEqual(backend.options._timeout, timedelta(seconds=40)) + class ProcessGroupGlooTest(Dist2MultiProcessTestCase): device = torch.device("cpu") @@ -216,7 +265,6 @@ def new_group(self) -> torch.distributed.ProcessGroup: backend="gloo", timeout=timedelta(seconds=60), device=self.device, - pg_options=None, ) @@ -231,15 +279,10 @@ def new_group(self) -> torch.distributed.ProcessGroup: self.device = torch.device("cuda", self.rank) - from torch.distributed import ProcessGroupNCCL - - opts = ProcessGroupNCCL.Options() - return dist2.new_group( backend="nccl", timeout=timedelta(seconds=60), device=self.device, - pg_options=opts, ) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d3436bbe4754..86410d8919d2 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -678,6 +678,88 @@ def test_fsdp_aot_eager(self): outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) + @config.patch(enable_compiler_collectives=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_dynamism_on_int_attr(self): + global GUARDS_FILE + GUARDS_FILE = StringIO() + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + + class ToyModelWithIntAttr(nn.Module): + def __init__(self): + super().__init__() + self.attr = 2 + + def forward(self, x): + out = x + self.attr + + @comptime + def _(ctx): + ctx.print_guards(file=GUARDS_FILE) + + return out + + def get_model_with_int_attr(device): + m = ToyModelWithIntAttr().to(device) + inputs = torch.rand(10).to(device) + outputs = m(inputs) + return m, inputs, outputs + + m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + compiled_fsdp_m = torch.compile( + fsdp_m, backend="eager", dynamic=True, fullgraph=True + ) + outputs = compiled_fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + FileCheck().check( + """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH""" + ).run(GUARDS_FILE.getvalue()) + + @config.patch(enable_compiler_collectives=True) + @config.patch(allow_unspec_int_on_fsdp_module=True) + @skip_if_lt_x_gpu(1) + def test_fsdp_dynamism_on_int_attr_unspec(self): + global GUARDS_FILE + GUARDS_FILE = StringIO() + + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + + class ToyModelWithIntAttr(nn.Module): + def __init__(self): + super().__init__() + self.attr = 2 + + def forward(self, x): + out = x + self.attr + + @comptime + def _(ctx): + ctx.print_guards(file=GUARDS_FILE) + + return out + + def get_model_with_int_attr(device): + m = ToyModelWithIntAttr().to(device) + inputs = torch.rand(10).to(device) + outputs = m(inputs) + return m, inputs, outputs + + m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + compiled_fsdp_m = torch.compile( + fsdp_m, backend="eager", dynamic=True, fullgraph=True + ) + outputs = compiled_fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # No presence of EQUALS_MATCH because the guard will be dynamic + FileCheck().check( + """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH""" + ).run(GUARDS_FILE.getvalue()) + @skip_if_lt_x_gpu(2) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_ddp_optimizer_cudagraph(self): diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index fad2f8195600..1f09d72ea2b1 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -19,6 +19,7 @@ from torch._inductor.comms import ( _reorder_communication_preserving_peak_memory_internal, ReorderInfo, + sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx from torch._inductor.scheduler import BaseSchedulerNode @@ -1621,7 +1622,7 @@ def test_reorder_peak_memory_bucketed(self): comm from moving due to data dependency. """ - def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): + def func(x, w, ag_0, ag_1, ag_2, ag_3, *, tag, ranks, group_size): # do some unrelated matmuls y = torch.mm(x, w) @@ -1654,14 +1655,52 @@ def func(x, w, ag_0, ag_1, *, tag, ranks, group_size): # wait op rs_0_out = torch.ops.c10d_functional.wait_tensor(rs_0_out) rs_1_out = torch.ops.c10d_functional.wait_tensor(rs_1_out) + y += torch.mm(2 * x, 2 * w) + + # cast the inputs + ag_2_cast = ag_2.to(torch.bfloat16) + ag_3_cast = ag_3.to(torch.bfloat16) + ag_2_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_2_cast, group_size, group_name + ) + ag_3_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_3_cast, group_size, group_name + ) + + # wait op + ag_2_out = torch.ops.c10d_functional.wait_tensor(ag_2_out) + ag_3_out = torch.ops.c10d_functional.wait_tensor(ag_3_out) + + # + rs_2_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_2_cast, "sum", group_size, group_name + ) + rs_3_out = torch.ops._c10d_functional.reduce_scatter_tensor( + ag_3_cast, "sum", group_size, group_name + ) - return y, ag_0_out, ag_1_out, rs_0_out, rs_1_out + # wait op + rs_2_out = torch.ops.c10d_functional.wait_tensor(rs_2_out) + rs_3_out = torch.ops.c10d_functional.wait_tensor(rs_3_out) + return ( + y, + ag_0_out, + ag_1_out, + ag_2_out, + ag_3_out, + rs_0_out, + rs_1_out, + rs_2_out, + rs_3_out, + ) x = torch.ones(4, 384, device="cuda", dtype=torch.float32) w = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_0 = torch.ones(384, 512, device="cuda", dtype=torch.float32) - ag_1 = torch.ones(512, device="cuda", dtype=torch.float32) - inputs = [x, w, ag_0, ag_1] + ag_0 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_1 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + ag_2 = torch.ones(1024, 512, device="cuda", dtype=torch.float32) + ag_3 = torch.ones(512, 1024, device="cuda", dtype=torch.float32) + inputs = [x, w, ag_0, ag_1, ag_2, ag_3] # get stats directly from the internal helper without affecting the real pass's signature node_stats: Optional[dict[BaseSchedulerNode, ReorderInfo]] = None @@ -1679,11 +1718,15 @@ def _reorder_communication_preserving_peak_memory( with torch._inductor.config.patch( { "bucket_all_gathers_fx": "all", + "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, "bucket_reduce_scatters_fx": "all", + "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], + "allow_buffer_reuse": False, } ): compiled = torch.compile(func) @@ -1694,30 +1737,29 @@ def _reorder_communication_preserving_peak_memory( FileCheck() .check_count( "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - count=1, + count=2, exactly=True, ) + .check( + "extern_kernels.mm", + ) + .check( + "extern_kernels.addmm", + ) .run(code) ) ( FileCheck() .check_count( "torch.ops._c10d_functional.reduce_scatter_tensor.default(", - count=1, + count=2, exactly=True, ) - .run(code) - ) - ( - FileCheck() - .check( - "torch.ops._c10d_functional.all_gather_into_tensor_out.default(", - ) .check( - "torch.ops._c10d_functional.reduce_scatter_tensor.default(", + "extern_kernels.mm", ) .check( - "extern_kernels.mm", + "extern_kernels.addmm", ) .run(code) ) @@ -1726,7 +1768,7 @@ def _reorder_communication_preserving_peak_memory( assert same(out, correct), f"{out} va {correct}" assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) - self.assertEqual(len(node_stats), 2) + self.assertEqual(len(node_stats), 4) it = iter(node_stats.values()) node_stat0 = next(it) self.assertTrue(node_stat0.moves > 0) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 196cebb1617c..7ca6d25ad1c9 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -25,6 +25,8 @@ from torch.testing._internal.common_utils import IS_SANDCASTLE, run_tests, TestCase +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + DEFAULT_WORLD_SIZE = 4 @@ -330,7 +332,7 @@ def backward(ctx, grad_output): return grad_output * result x = torch.tensor( - [dist.get_rank()], dtype=torch.float, device="cuda", requires_grad=True + [dist.get_rank()], dtype=torch.float, device=device_type, requires_grad=True ) x = MyFunc.apply(x) x.sum().backward() diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 2aabf9242784..c4565a96496c 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -4,6 +4,8 @@ # python test/distributed/test_nvshmem_triton.py +import triton.language as tl + import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem @@ -13,7 +15,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, - skip_but_pass_in_sandcastle, skip_but_pass_in_sandcastle_if, skipIfRocm, ) @@ -149,6 +150,87 @@ def put_with_quiet_kernel( nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer) +@triton.jit +def barrier_test_kernel( + dst_ptr, + src_ptr, + numel, +): + # Testing barrier_all() requires coordinated operations across PEs within + # the same kernel execution. Unlike other kernels that just wrap NVSHMEM + # primitives, this one implements the full test logic to properly verify + # device-side barrier synchronization. + my_pe = nvshmem.my_pe() + n_pes = nvshmem.n_pes() + # Rank 0 broadcasts its value to all other ranks + if my_pe == 0: + # Write initial value + p_src = src_ptr.to(tl.pointer_type(tl.int32)) + tl.store(p_src, 42) + # Put to all other ranks + i = 1 + while i < n_pes: + nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + i += 1 + # Synchronize all PEs + nvshmem.barrier_all() + # Non-zero ranks increment the received value + if my_pe != 0: + p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + received = tl.load(p_dst) + tl.store(p_dst, received + 1) + + +@triton.jit +def sync_test_kernel( + dst_ptr, + src_ptr, + numel, +): + my_pe = nvshmem.my_pe() + n_pes = nvshmem.n_pes() + + # Rank 0 broadcasts its value to all other ranks + if my_pe == 0: + # Write initial value + p_src = src_ptr.to(tl.pointer_type(tl.int32)) + tl.store(p_src, 42) + # Put to all other ranks + i = 1 + while i < n_pes: + nvshmem.putmem_block(dst_ptr, src_ptr, numel, i) + i += 1 + # Synchronize all PEs (this is more lightweight than barrier_all() b/c it only ensures local store visibility + # and doesn't wait for remote ops to complete) + nvshmem.sync_all() + # Non-zero ranks increment the received value + if my_pe != 0: + p_dst = dst_ptr.to(tl.pointer_type(tl.int32)) + received = tl.load(p_dst) + tl.store(p_dst, received + 1) + + +@triton.jit +def alltoall_kernel( + team_handle, + dest_ptr, + src_ptr, + nelems, +): + nvshmem.alltoall(team_handle, dest_ptr, src_ptr, nelems) + + +@triton.jit +def broadcast_kernel( + team_handle, + dest_ptr, + src_ptr, + nelems, + pe_root, +): + nvshmem.broadcast(team_handle, dest_ptr, src_ptr, nelems, pe_root) + + @instantiate_parametrized_tests @requires_nvshmem() class NVSHMEMTritonTest(MultiProcContinousTest): @@ -173,7 +255,7 @@ def test_triton_put(self) -> None: # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -187,7 +269,7 @@ def test_triton_put(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) - peer = 1 - rank + peer = (self.world_size - 1) - rank if rank == 0: dst_ptr = out_hdl.buffer_ptrs[rank] src_ptr = inp_hdl.buffer_ptrs[rank] @@ -212,7 +294,7 @@ def test_triton_get(self) -> None: self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank msg_size_bytes = 8 @@ -226,7 +308,7 @@ def test_triton_get(self) -> None: inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) dist.barrier() - peer = 1 - rank + peer = (self.world_size - 1) - rank if rank == 1: # Rank 1 gets data from rank 0 dst_ptr = out_hdl.buffer_ptrs[rank] @@ -250,7 +332,7 @@ def test_triton_get_ring(self) -> None: self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank world_size = dist.get_world_size() @@ -293,7 +375,7 @@ def test_triton_put_signal_set(self) -> None: nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -312,7 +394,7 @@ def test_triton_put_signal_set(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set SIGNAL_VAL = 1 # Signal completion value NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until @@ -358,7 +440,7 @@ def test_triton_put_signal_add(self) -> None: nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank @@ -377,7 +459,7 @@ def test_triton_put_signal_add(self) -> None: # as the flag buffer for signaling completion. flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0) - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_SIGNAL_ADD = 5 # atomic add operation SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD NVSHMEM_CMP_EQ = 0 @@ -413,50 +495,54 @@ def test_triton_put_signal_add(self) -> None: flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device) ) - # This test hangs. TODO: investigate why. - @skip_but_pass_in_sandcastle("Hangs") @skipIfRocm @requires_triton() def test_triton_wait_until(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() + nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + peer = (self.world_size - 1) - rank + NVSHMEM_CMP_EQ = 0 # from nvshmem.h - # Data buffers + # Allocate symmetric buffers msg_size_bytes = 8 dtype = torch.int8 numel = msg_size_bytes // dtype.itemsize val = 13 flag_val = 21 + inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val) out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1) + inp_hdl = symm_mem.rendezvous(inp, group=group_name) out_hdl = symm_mem.rendezvous(out, group=group_name) - peer = 1 - rank - NVSHMEM_CMP_EQ = 0 # from nvshmem.h - NVSHMEM_SIGNAL_SET = 0 # atomic set operation - if rank == 0: # Rank 0 waits for the flag to be set by Rank 1, then checks the data ivar_ptr = out_hdl.signal_pad_ptrs[rank] + wait_until_kernel[(1, 1, 1)]( ivar_ptr, cmp_op=NVSHMEM_CMP_EQ, cmp_val=flag_val, extern_libs=nvshmem_lib, ) + torch.testing.assert_close( - out, val * torch.ones(numel, dtype=dtype, device=self.device) + out, + val * torch.ones(numel, dtype=dtype, device=self.device), ) if rank == 1: # Rank 1 puts data into Rank 0's output buffer - dst_ptr = out_hdl.buffer_ptrs[rank] + dst_ptr = out_hdl.buffer_ptrs[peer] src_ptr = inp_hdl.buffer_ptrs[rank] + put_kernel[(1, 1, 1)]( dst_ptr, src_ptr, @@ -465,12 +551,21 @@ def test_triton_wait_until(self) -> None: extern_libs=nvshmem_lib, ) - # Rank 1 sets the flag on Rank 0 using nvshmemx_signal_op - sig_addr = out_hdl.signal_pad_ptrs[rank] - signal_op_kernel[(1, 1, 1)]( - sig_addr, - signal=flag_val, - sig_op=NVSHMEM_SIGNAL_SET, + # Fence to order data put before flag put + @triton.jit + def fence_kernel(): + nvshmem.fence() + + fence_kernel[(1, 1, 1)](extern_libs=nvshmem_lib) + + # Put the flag value (do not use signal_op here) + flag_src = torch.tensor([flag_val], dtype=torch.int64, device=self.device) + flag_dst_ptr = out_hdl.signal_pad_ptrs[peer] + + put_kernel[(1, 1, 1)]( + flag_dst_ptr, + flag_src.data_ptr(), + numel=1, peer=peer, extern_libs=nvshmem_lib, ) @@ -481,10 +576,10 @@ def test_triton_signal_wait_until(self) -> None: self._init_device() # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = 1 - rank + peer = (self.world_size - 1) - rank # NVSHMEM constants from documentation NVSHMEM_CMP_EQ = 0 # equal comparison @@ -557,10 +652,10 @@ def test_triton_fence(self) -> None: torch.manual_seed(42 + self.rank) self._init_device() nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank - peer = 1 - rank + peer = (self.world_size - 1) - rank # Message configuration msg_size_bytes = 8 dtype = torch.int8 @@ -632,7 +727,7 @@ def test_triton_quiet(self) -> None: self._init_device() # Enable NVSHMEM for Triton nvshmem_lib = nvshmem.enable_triton() - group_name = dist.group.WORLD.group_name + group_name = dist.distributed_c10d._get_default_group().group_name symm_mem.enable_symm_mem_for_group(group_name) rank = self.rank msg_size_bytes = 8 @@ -646,7 +741,7 @@ def test_triton_quiet(self) -> None: out_hdl = symm_mem.rendezvous(out, group=group_name) # Use signal pad as completion flag flag_val = 42 - peer = 1 - rank + peer = (self.world_size - 1) - rank NVSHMEM_CMP_EQ = 0 if rank == 0: @@ -682,6 +777,176 @@ def test_triton_quiet(self) -> None: extern_libs=nvshmem_lib, ) + @skipIfRocm + @requires_triton() + def test_triton_barrier(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.distributed_c10d._get_default_group().group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + numel = 1 + dtype = torch.int32 + # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Launch kernel with cooperative grid + barrier_test_kernel[(1,)]( + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + numel=numel, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + num_ctas=1, + ) + # Verify results + # Rank 0 should have 42, and then the rest should have incremented + 1 to 43 + if rank == 0: + # Rank 0 should have its original value (42) in src + torch.testing.assert_close( + src, torch.tensor([42], device=self.device, dtype=dtype) + ) + else: + # Other ranks should have received 42 and incremented to 43 + torch.testing.assert_close( + dst, torch.tensor([43], device=self.device, dtype=dtype) + ) + + @skipIfRocm + @requires_triton() + def test_triton_sync(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + numel = 1 + dtype = torch.int32 + # Create symmetric buffers + src = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + dst = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(0) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Launch kernel with cooperative grid + sync_test_kernel[(1,)]( + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + numel=numel, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + num_ctas=1, + ) + # Verify results + if rank == 0: + # Rank 0 should have its original value (42) in src + torch.testing.assert_close( + src, torch.tensor([42], device=self.device, dtype=dtype) + ) + else: + # Other ranks should have received 42 and incremented to 43 + torch.testing.assert_close( + dst, torch.tensor([43], device=self.device, dtype=dtype) + ) + + @skipIfRocm + @requires_triton() + def test_triton_alltoall(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + world_size = dist.get_world_size() + rank = self.rank + # Each PE will send 2 int64 elements to every other PE + nelems_per_pe = 2 + dtype = torch.int64 + # Source buffer: contains data for all PEs + # Layout: [data_for_pe0, data_for_pe1, ...] + src_size = nelems_per_pe * world_size + src = symm_mem.empty(src_size, dtype=dtype, device=self.device) + # Fill source with rank-specific data + # Formula: rank * 100 + destination_pe + for i in range(world_size): + value = rank * 100 + i + src[i * nelems_per_pe : (i + 1) * nelems_per_pe] = value + # Destination buffer + dst = symm_mem.empty(src_size, dtype=dtype, device=self.device).fill_(-1) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Synchronize before alltoall + dist.barrier() + team_handle = 0 # NVSHMEM_TEAM_WORLD handle is 0 + # Launch the kernel + alltoall_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nelems_per_pe, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Synchronize after alltoall + dist.barrier() + # Verify results + for i in range(world_size): + # After alltoall, we should receive data from PE i that was intended for us + # PE i sends (i * 100 + rank) to us + expected = i * 100 + rank + actual = dst[i * nelems_per_pe : (i + 1) * nelems_per_pe] + torch.testing.assert_close(actual, torch.full_like(actual, expected)) + + @skipIfRocm + @requires_triton() + def test_triton_broadcast(self) -> None: + torch.manual_seed(42 + self.rank) + self._init_device() + nvshmem_lib = nvshmem.enable_triton() + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + rank = self.rank + # Configuration + nelems = 4 # number of elements + dtype = torch.int64 + # Source buffer - only root will have meaningful data + pe_root = 0 # PE 0 will be the root + src = symm_mem.empty(nelems, dtype=dtype, device=self.device) + if rank == pe_root: + # Root fills with specific pattern + for i in range(nelems): + src[i] = 100 + i + else: + # Non-root PEs have dummy data + src.fill_(-1) + # Destination buffer + dst = symm_mem.empty(nelems, dtype=dtype, device=self.device).fill_(-999) + src_hdl = symm_mem.rendezvous(src, group=group_name) + dst_hdl = symm_mem.rendezvous(dst, group=group_name) + # Synchronize before broadcast + dist.barrier() + # Execute broadcast + team_handle = 0 # NVSHMEM_TEAM_WORLD + broadcast_kernel[(1,)]( + team_handle, + dst_hdl.buffer_ptrs[rank], + src_hdl.buffer_ptrs[rank], + nelems, + pe_root, + extern_libs=nvshmem_lib, + launch_cooperative_grid=True, + ) + # Synchronize after broadcast + dist.barrier() + # Verify results - all ranks should have the root's data + expected = [100 + i for i in range(nelems)] + torch.testing.assert_close( + dst, torch.tensor(expected, device=self.device, dtype=dtype) + ) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index f6f7fcfc3885..ed39107a0676 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -1065,10 +1065,6 @@ class SymmMemSingleProcTest(TestCase): not TEST_WITH_ROCM and _get_torch_cuda_version() < (12, 0), "stream_write_value32 currently only supports cuda version>=12.0", ) - @skipIf( - _get_torch_cuda_version() >= (12, 6), - "https://github.com/pytorch/pytorch/issues/154073", - ) @runOnRocmArch(MI300_ARCH) def test_stream_write_value32(self): tensor = torch.zeros(4, dtype=torch.uint32, device="cuda") diff --git a/test/dynamo/cpython/3_13/list_tests.diff b/test/dynamo/cpython/3_13/list_tests.diff index 903895b384b5..7889011f375d 100644 --- a/test/dynamo/cpython/3_13/list_tests.diff +++ b/test/dynamo/cpython/3_13/list_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py -index dbc5ef4f9f2..239b75f74cc 100644 +index dbc5ef4f9f2..70e24036f74 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py -@@ -1,3 +1,53 @@ +@@ -1,3 +1,56 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -56,7 +59,7 @@ index dbc5ef4f9f2..239b75f74cc 100644 """ Tests common to list and UserList.UserList """ -@@ -5,7 +55,7 @@ Tests common to list and UserList.UserList +@@ -5,7 +58,7 @@ Tests common to list and UserList.UserList import sys from functools import cmp_to_key @@ -65,7 +68,7 @@ index dbc5ef4f9f2..239b75f74cc 100644 from test.support import ALWAYS_EQ, NEVER_EQ, get_c_recursion_limit -@@ -119,10 +169,6 @@ class CommonTest(seq_tests.CommonTest): +@@ -119,10 +172,6 @@ class CommonTest(seq_tests.CommonTest): a[-1] = 9 self.assertEqual(a, self.type2test([5,6,7,8,9])) diff --git a/test/dynamo/cpython/3_13/list_tests.py b/test/dynamo/cpython/3_13/list_tests.py index 239b75f74cc4..70e24036f74d 100644 --- a/test/dynamo/cpython/3_13/list_tests.py +++ b/test/dynamo/cpython/3_13/list_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/list_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/mapping_tests.diff b/test/dynamo/cpython/3_13/mapping_tests.diff index 03ae75513d66..009b53f31b55 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.diff +++ b/test/dynamo/cpython/3_13/mapping_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py -index ed89a81a6ea..eed59a68e94 100644 +index ed89a81a6ea..10fc6e7e467 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py -@@ -1,10 +1,61 @@ +@@ -1,10 +1,64 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py ++ +import sys +import torch +import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/mapping_tests.py b/test/dynamo/cpython/3_13/mapping_tests.py index eed59a68e944..10fc6e7e4672 100644 --- a/test/dynamo/cpython/3_13/mapping_tests.py +++ b/test/dynamo/cpython/3_13/mapping_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/mapping_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/seq_tests.diff b/test/dynamo/cpython/3_13/seq_tests.diff index 03c7021e4f96..b87c26ece27c 100644 --- a/test/dynamo/cpython/3_13/seq_tests.diff +++ b/test/dynamo/cpython/3_13/seq_tests.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py -index 719c9434a16..4325892276d 100644 +index 719c9434a16..2c502cda4f6 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 719c9434a16..4325892276d 100644 """ Tests common to tuple, list and UserList.UserList """ -@@ -95,7 +146,7 @@ class LyingList(list): +@@ -95,7 +149,7 @@ class LyingList(list): def __iter__(self): yield 1 diff --git a/test/dynamo/cpython/3_13/seq_tests.py b/test/dynamo/cpython/3_13/seq_tests.py index 4325892276d4..2c502cda4f61 100644 --- a/test/dynamo/cpython/3_13/seq_tests.py +++ b/test/dynamo/cpython/3_13/seq_tests.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/seq_tests.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_baseexception.diff b/test/dynamo/cpython/3_13/test_baseexception.diff index b25d72d0f65d..240e4e554d6a 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.diff +++ b/test/dynamo/cpython/3_13/test_baseexception.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py -index e599b02c17d..3dc102e3b8a 100644 +index e599b02c17d..750d7a84fb4 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.py +++ b/test/dynamo/cpython/3_13/test_baseexception.py -@@ -1,10 +1,61 @@ +@@ -1,10 +1,64 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -65,7 +68,7 @@ index e599b02c17d..3dc102e3b8a 100644 """Tests for anything relating to exception objects themselves (e.g., inheritance hierarchy)""" -@@ -78,9 +129,6 @@ class ExceptionClassTests(unittest.TestCase): +@@ -78,9 +132,6 @@ class ExceptionClassTests(unittest.TestCase): last_depth = depth finally: inheritance_tree.close() @@ -75,7 +78,7 @@ index e599b02c17d..3dc102e3b8a 100644 self.assertEqual(len(exc_set), 0, "%s not accounted for" % exc_set) interface_tests = ("length", "args", "str", "repr") -@@ -142,7 +190,7 @@ class ExceptionClassTests(unittest.TestCase): +@@ -142,7 +193,7 @@ class ExceptionClassTests(unittest.TestCase): gc.collect() @@ -84,7 +87,7 @@ index e599b02c17d..3dc102e3b8a 100644 """Test usage of exceptions""" -@@ -208,5 +256,5 @@ class UsageTests(unittest.TestCase): +@@ -208,5 +259,5 @@ class UsageTests(unittest.TestCase): self.catch_fails("spam") diff --git a/test/dynamo/cpython/3_13/test_baseexception.py b/test/dynamo/cpython/3_13/test_baseexception.py index 3dc102e3b8a2..750d7a84fb45 100644 --- a/test/dynamo/cpython/3_13/test_baseexception.py +++ b/test/dynamo/cpython/3_13/test_baseexception.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_baseexception.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_cmath.diff b/test/dynamo/cpython/3_13/test_cmath.diff index 7157e8c0498f..c229add52902 100644 --- a/test/dynamo/cpython/3_13/test_cmath.diff +++ b/test/dynamo/cpython/3_13/test_cmath.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py -index a96a5780b31..883e87a0733 100644 +index a96a5780b31..37fb665d97d 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py -@@ -1,5 +1,55 @@ +@@ -1,5 +1,58 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -59,7 +62,7 @@ index a96a5780b31..883e87a0733 100644 from test.test_math import parse_testfile, test_file import test.test_math as test_math import unittest -@@ -50,7 +100,7 @@ complex_nans = [complex(x, y) for x, y in [ +@@ -50,7 +103,7 @@ complex_nans = [complex(x, y) for x, y in [ (INF, NAN) ]] @@ -68,7 +71,7 @@ index a96a5780b31..883e87a0733 100644 # list of all functions in cmath test_functions = [getattr(cmath, fname) for fname in [ 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', -@@ -66,6 +116,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -66,6 +119,39 @@ class CMathTests(ComplexesAreIdenticalMixin, unittest.TestCase): def tearDown(self): self.test_values.close() @@ -108,7 +111,7 @@ index a96a5780b31..883e87a0733 100644 def rAssertAlmostEqual(self, a, b, rel_err = 2e-15, abs_err = 5e-323, msg=None): """Fail if the two floating-point numbers are not almost equal. -@@ -590,4 +673,4 @@ class IsCloseTests(test_math.IsCloseTests): +@@ -590,4 +676,4 @@ class IsCloseTests(test_math.IsCloseTests): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_cmath.py b/test/dynamo/cpython/3_13/test_cmath.py index 883e87a07337..37fb665d97d2 100644 --- a/test/dynamo/cpython/3_13/test_cmath.py +++ b/test/dynamo/cpython/3_13/test_cmath.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_cmath.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_complex.diff b/test/dynamo/cpython/3_13/test_complex.diff index a7867e47f227..57a2d4315f21 100644 --- a/test/dynamo/cpython/3_13/test_complex.diff +++ b/test/dynamo/cpython/3_13/test_complex.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py -index 6ff1a8ab29d..ab5bd3dab62 100644 +index 6ff1a8ab29d..cda348d2f37 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py -@@ -1,16 +1,143 @@ +@@ -1,16 +1,146 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -151,7 +154,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 INF = float("inf") NAN = float("nan") DBL_MAX = sys.float_info.max -@@ -45,7 +172,40 @@ class WithComplex: +@@ -45,7 +175,40 @@ class WithComplex: def __complex__(self): return self.value @@ -193,7 +196,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 def assertAlmostEqual(self, a, b): if isinstance(a, complex): -@@ -74,6 +234,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -74,6 +237,29 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): # check that relative difference < eps self.assertTrue(abs((x-y)/y) < eps) @@ -223,7 +226,7 @@ index 6ff1a8ab29d..ab5bd3dab62 100644 def assertClose(self, x, y, eps=1e-9): """Return true iff complexes x and y "are close".""" self.assertCloseAbs(x.real, y.real, eps) -@@ -855,4 +1038,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): +@@ -855,4 +1041,4 @@ class ComplexTest(ComplexesAreIdenticalMixin, unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_complex.py b/test/dynamo/cpython/3_13/test_complex.py index ab5bd3dab62b..cda348d2f377 100644 --- a/test/dynamo/cpython/3_13/test_complex.py +++ b/test/dynamo/cpython/3_13/test_complex.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_complex.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_contextlib.diff b/test/dynamo/cpython/3_13/test_contextlib.diff index f3314f590c10..3850f6696681 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.diff +++ b/test/dynamo/cpython/3_13/test_contextlib.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py -index cf651959803..6a17bc719eb 100644 +index cf651959803..51fd083b112 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.py +++ b/test/dynamo/cpython/3_13/test_contextlib.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index cf651959803..6a17bc719eb 100644 """Unit tests for contextlib.py, and other context managers.""" import io -@@ -14,7 +65,7 @@ from test.support.testcase import ExceptionIsLikeMixin +@@ -14,7 +68,7 @@ from test.support.testcase import ExceptionIsLikeMixin import weakref @@ -66,7 +69,7 @@ index cf651959803..6a17bc719eb 100644 def test_enter(self): class DefaultEnter(AbstractContextManager): -@@ -67,7 +118,7 @@ class TestAbstractContextManager(unittest.TestCase): +@@ -67,7 +121,7 @@ class TestAbstractContextManager(unittest.TestCase): self.assertFalse(issubclass(NoExit, AbstractContextManager)) @@ -75,7 +78,7 @@ index cf651959803..6a17bc719eb 100644 def test_contextmanager_plain(self): state = [] -@@ -396,7 +447,7 @@ def woohoo(): +@@ -396,7 +450,7 @@ def woohoo(): self.assertEqual(depth, 0) @@ -84,7 +87,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -430,7 +481,7 @@ class ClosingTestCase(unittest.TestCase): +@@ -430,7 +484,7 @@ class ClosingTestCase(unittest.TestCase): self.assertEqual(state, [1]) @@ -93,7 +96,7 @@ index cf651959803..6a17bc719eb 100644 def test_nullcontext(self): class C: pass -@@ -439,7 +490,7 @@ class NullcontextTestCase(unittest.TestCase): +@@ -439,7 +493,7 @@ class NullcontextTestCase(unittest.TestCase): self.assertIs(c_in, c) @@ -102,7 +105,7 @@ index cf651959803..6a17bc719eb 100644 def testWithOpen(self): tfn = tempfile.mktemp() -@@ -457,7 +508,7 @@ class FileContextTestCase(unittest.TestCase): +@@ -457,7 +511,7 @@ class FileContextTestCase(unittest.TestCase): finally: os_helper.unlink(tfn) @@ -111,7 +114,7 @@ index cf651959803..6a17bc719eb 100644 def boilerPlate(self, lock, locked): self.assertFalse(locked()) -@@ -520,7 +571,7 @@ class mycontext(ContextDecorator): +@@ -520,7 +574,7 @@ class mycontext(ContextDecorator): return self.catch @@ -120,7 +123,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -680,7 +731,7 @@ class TestContextDecorator(unittest.TestCase): +@@ -680,7 +734,7 @@ class TestContextDecorator(unittest.TestCase): self.assertEqual(state, [1, 'something else', 999]) @@ -129,7 +132,7 @@ index cf651959803..6a17bc719eb 100644 exit_stack = None @support.requires_docstrings -@@ -1141,7 +1192,7 @@ class TestBaseExitStack: +@@ -1141,7 +1195,7 @@ class TestBaseExitStack: self.assertIs(exc.__cause__, exc.__context__) @@ -138,7 +141,7 @@ index cf651959803..6a17bc719eb 100644 exit_stack = ExitStack callback_error_internal_frames = [ ('__exit__', 'raise exc'), -@@ -1149,7 +1200,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase): +@@ -1149,7 +1203,7 @@ class TestExitStack(TestBaseExitStack, unittest.TestCase): ] @@ -147,7 +150,7 @@ index cf651959803..6a17bc719eb 100644 redirect_stream = None orig_stream = None -@@ -1206,19 +1257,19 @@ class TestRedirectStream: +@@ -1206,19 +1260,19 @@ class TestRedirectStream: self.assertEqual(s, "Hello World!\n") @@ -170,7 +173,7 @@ index cf651959803..6a17bc719eb 100644 @support.requires_docstrings def test_instance_docs(self): -@@ -1315,7 +1366,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): +@@ -1315,7 +1369,7 @@ class TestSuppress(ExceptionIsLikeMixin, unittest.TestCase): ) @@ -179,7 +182,7 @@ index cf651959803..6a17bc719eb 100644 def make_relative_path(self, *parts): return os.path.join( os.path.dirname(os.path.realpath(__file__)), -@@ -1331,6 +1382,7 @@ class TestChdir(unittest.TestCase): +@@ -1331,6 +1385,7 @@ class TestChdir(unittest.TestCase): self.assertEqual(os.getcwd(), target) self.assertEqual(os.getcwd(), old_cwd) @@ -187,7 +190,7 @@ index cf651959803..6a17bc719eb 100644 def test_reentrant(self): old_cwd = os.getcwd() target1 = self.make_relative_path('data') -@@ -1363,4 +1415,4 @@ class TestChdir(unittest.TestCase): +@@ -1363,4 +1418,4 @@ class TestChdir(unittest.TestCase): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_contextlib.py b/test/dynamo/cpython/3_13/test_contextlib.py index 6a17bc719eb9..51fd083b1129 100644 --- a/test/dynamo/cpython/3_13/test_contextlib.py +++ b/test/dynamo/cpython/3_13/test_contextlib.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_contextlib.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_dict.diff b/test/dynamo/cpython/3_13/test_dict.diff index 9589bcf797bd..0c6beec66dad 100644 --- a/test/dynamo/cpython/3_13/test_dict.diff +++ b/test/dynamo/cpython/3_13/test_dict.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py -index 4729132c5a5..14f829c1715 100644 +index 4c095464cbb..fcda6484ea6 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py -@@ -1,3 +1,57 @@ +@@ -1,3 +1,60 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -60,7 +63,7 @@ index 4729132c5a5..14f829c1715 100644 import collections import collections.abc import gc -@@ -11,7 +65,7 @@ from test import support +@@ -11,7 +68,7 @@ from test import support from test.support import import_helper, get_c_recursion_limit @@ -69,15 +72,48 @@ index 4729132c5a5..14f829c1715 100644 def test_invalid_keyword_arguments(self): class Custom(dict): -@@ -265,6 +319,7 @@ class DictTest(unittest.TestCase): +@@ -265,39 +322,7 @@ class DictTest(unittest.TestCase): self.assertRaises(ValueError, {}.update, [(1, 2, 3)]) +- def test_update_shared_keys(self): +- class MyClass: pass +- +- # Subclass str to enable us to create an object during the +- # dict.update() call. +- class MyStr(str): +- def __hash__(self): +- return super().__hash__() +- +- def __eq__(self, other): +- # Create an object that shares the same PyDictKeysObject as +- # obj.__dict__. +- obj2 = MyClass() +- obj2.a = "a" +- obj2.b = "b" +- obj2.c = "c" +- return super().__eq__(other) +- +- obj = MyClass() +- obj.a = "a" +- obj.b = "b" +- +- x = {} +- x[MyStr("a")] = MyStr("a") +- +- # gh-132617: this previously raised "dict mutated during update" error +- x.update(obj.__dict__) +- +- self.assertEqual(x, { +- MyStr("a"): "a", +- "b": "b", +- }) +- + @unittest.skip("test hangs") def test_fromkeys(self): self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None}) d = {} -@@ -477,7 +532,7 @@ class DictTest(unittest.TestCase): +@@ -510,7 +535,7 @@ class DictTest(unittest.TestCase): for copymode in -1, +1: # -1: b has same structure as a # +1: b is a.copy() @@ -86,7 +122,7 @@ index 4729132c5a5..14f829c1715 100644 size = 2**log2size a = {} b = {} -@@ -1006,18 +1061,6 @@ class DictTest(unittest.TestCase): +@@ -1039,18 +1064,6 @@ class DictTest(unittest.TestCase): pass self._tracked(MyDict()) @@ -105,7 +141,7 @@ index 4729132c5a5..14f829c1715 100644 def make_shared_key_dict(self, n): class C: pass -@@ -1622,7 +1665,7 @@ class DictTest(unittest.TestCase): +@@ -1655,7 +1668,7 @@ class DictTest(unittest.TestCase): self.assertGreaterEqual(eq_count, 1) @@ -114,7 +150,7 @@ index 4729132c5a5..14f829c1715 100644 # Test _PyDict_GetItem_KnownHash() @support.cpython_only -@@ -1666,4 +1709,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): +@@ -1699,4 +1712,4 @@ class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_dict.py b/test/dynamo/cpython/3_13/test_dict.py index 14f829c1715c..fcda6484ea60 100644 --- a/test/dynamo/cpython/3_13/test_dict.py +++ b/test/dynamo/cpython/3_13/test_dict.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_dict.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_exception_variations.diff b/test/dynamo/cpython/3_13/test_exception_variations.diff index 45424e087b5a..52ae731d9493 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.diff +++ b/test/dynamo/cpython/3_13/test_exception_variations.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py -index a83a41d2975..be432089e3a 100644 +index a83a41d2975..c2d6eb3a41a 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.py +++ b/test/dynamo/cpython/3_13/test_exception_variations.py -@@ -1,7 +1,59 @@ +@@ -1,7 +1,62 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -53,17 +56,17 @@ index a83a41d2975..be432089e3a 100644 +# Add the custom finder to sys.meta_path +sys.meta_path.insert(0, RedirectImportFinder()) + -+ -+# ======= END DYNAMO PATCH ======= -class ExceptTestCases(unittest.TestCase): ++# ======= END DYNAMO PATCH ======= ++ +import unittest + +class ExceptTestCases(__TestCase): def test_try_except_else_finally(self): hit_except = False hit_else = False -@@ -294,282 +346,5 @@ class ExceptTestCases(unittest.TestCase): +@@ -294,282 +349,5 @@ class ExceptTestCases(unittest.TestCase): self.assertTrue(hit_except) diff --git a/test/dynamo/cpython/3_13/test_exception_variations.py b/test/dynamo/cpython/3_13/test_exception_variations.py index be432089e3a3..c2d6eb3a41af 100644 --- a/test/dynamo/cpython/3_13/test_exception_variations.py +++ b/test/dynamo/cpython/3_13/test_exception_variations.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exception_variations.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_exceptions.diff b/test/dynamo/cpython/3_13/test_exceptions.diff index e69de29bb2d1..6dcc9c858a9f 100644 --- a/test/dynamo/cpython/3_13/test_exceptions.diff +++ b/test/dynamo/cpython/3_13/test_exceptions.diff @@ -0,0 +1,152 @@ +diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py +index c91f6662948..0ded70db3c7 100644 +--- a/test/dynamo/cpython/3_13/test_exceptions.py ++++ b/test/dynamo/cpython/3_13/test_exceptions.py +@@ -1,3 +1,59 @@ ++# ======= BEGIN Dynamo patch ======= ++# Owner(s): ["module: dynamo"] ++ ++# ruff: noqa ++# flake8: noqa ++ ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py ++ ++import sys ++import torch ++import torch._dynamo.test_case ++import unittest ++from torch._dynamo.test_case import CPythonTestCase ++from torch.testing._internal.common_utils import ( ++ run_tests, ++ xfailIfTorchDynamo, ++) ++ ++__TestCase = CPythonTestCase ++ ++# redirect import statements ++import sys ++import importlib.abc ++ ++redirect_imports = ( ++ "test.mapping_tests", ++ "test.typinganndata", ++ "test.test_grammar", ++ "test.test_math", ++ "test.test_iter", ++ "test.typinganndata.ann_module", ++) ++ ++class RedirectImportFinder(importlib.abc.MetaPathFinder): ++ def find_spec(self, fullname, path, target=None): ++ # Check if the import is the problematic one ++ if fullname in redirect_imports: ++ try: ++ # Attempt to import the standalone module ++ name = fullname.removeprefix("test.") ++ r = importlib.import_module(name) ++ # Redirect the module in sys.modules ++ sys.modules[fullname] = r ++ # Return a module spec from the found module ++ return importlib.util.find_spec(name) ++ except ImportError: ++ return None ++ return None ++ ++# Add the custom finder to sys.meta_path ++sys.meta_path.insert(0, RedirectImportFinder()) ++ ++ ++# ======= END DYNAMO PATCH ======= ++ + # Python test set -- part 5, built-in exceptions + + import copy +@@ -45,7 +101,7 @@ class BrokenStrException(Exception): + # XXX This is not really enough, each *operation* should be tested! + + +-class ExceptionTests(unittest.TestCase): ++class ExceptionTests(__TestCase): + + def raise_catch(self, exc, excname): + with self.subTest(exc=exc, excname=excname): +@@ -1844,7 +1900,7 @@ class ExceptionTests(unittest.TestCase): + self.assertIn(b'MemoryError', err) + + +-class NameErrorTests(unittest.TestCase): ++class NameErrorTests(__TestCase): + def test_name_error_has_name(self): + try: + bluch +@@ -1894,7 +1950,7 @@ class NameErrorTests(unittest.TestCase): + # Note: name suggestion tests live in `test_traceback`. + + +-class AttributeErrorTests(unittest.TestCase): ++class AttributeErrorTests(__TestCase): + def test_attributes(self): + # Setting 'attr' should not be a problem. + exc = AttributeError('Ouch!') +@@ -1937,7 +1993,7 @@ class AttributeErrorTests(unittest.TestCase): + # Note: name suggestion tests live in `test_traceback`. + + +-class ImportErrorTests(unittest.TestCase): ++class ImportErrorTests(__TestCase): + + def test_attributes(self): + # Setting 'name' and 'path' should not be a problem. +@@ -2024,7 +2080,7 @@ def run_script(source): + _rc, _out, err = script_helper.assert_python_failure('-Wd', '-X', 'utf8', TESTFN) + return err.decode('utf-8').splitlines() + +-class AssertionErrorTests(unittest.TestCase): ++class AssertionErrorTests(__TestCase): + def tearDown(self): + unlink(TESTFN) + +@@ -2159,7 +2215,7 @@ class AssertionErrorTests(unittest.TestCase): + + + @support.force_not_colorized_test_class +-class SyntaxErrorTests(unittest.TestCase): ++class SyntaxErrorTests(__TestCase): + maxDiff = None + + @force_not_colorized +@@ -2290,6 +2346,7 @@ class SyntaxErrorTests(unittest.TestCase): + err = run_script(b"\x89") + self.assertIn("SyntaxError: Non-UTF-8 code starting with '\\x89' in file", err[-1]) + ++ + def test_string_source(self): + def try_compile(source): + with self.assertRaises(SyntaxError) as cm: +@@ -2405,7 +2462,7 @@ class SyntaxErrorTests(unittest.TestCase): + self.assertRaises(TypeError, SyntaxError, "bad bad", args) + + +-class TestInvalidExceptionMatcher(unittest.TestCase): ++class TestInvalidExceptionMatcher(__TestCase): + def test_except_star_invalid_exception_type(self): + with self.assertRaises(TypeError): + try: +@@ -2420,7 +2477,7 @@ class TestInvalidExceptionMatcher(unittest.TestCase): + pass + + +-class PEP626Tests(unittest.TestCase): ++class PEP626Tests(__TestCase): + + def lineno_after_raise(self, f, *expected): + try: +@@ -2529,5 +2586,5 @@ class PEP626Tests(unittest.TestCase): + 1/0 + self.lineno_after_raise(after_with, 1, 1) + +-if __name__ == '__main__': +- unittest.main() ++if __name__ == "__main__": ++ run_tests() diff --git a/test/dynamo/cpython/3_13/test_exceptions.py b/test/dynamo/cpython/3_13/test_exceptions.py index e6a9a2676bc0..0ded70db3c78 100644 --- a/test/dynamo/cpython/3_13/test_exceptions.py +++ b/test/dynamo/cpython/3_13/test_exceptions.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_exceptions.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_float.diff b/test/dynamo/cpython/3_13/test_float.diff index 6b8586b1c663..73cd65364fbc 100644 --- a/test/dynamo/cpython/3_13/test_float.diff +++ b/test/dynamo/cpython/3_13/test_float.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_float.py b/test/dynamo/cpython/3_13/test_float.py -index 97f951f1299..ce2c46777e0 100644 +index 87af79eb446..9313a1a63d7 100644 --- a/test/dynamo/cpython/3_13/test_float.py +++ b/test/dynamo/cpython/3_13/test_float.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_float.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 97f951f1299..ce2c46777e0 100644 import fractions import operator import os -@@ -8,11 +59,84 @@ import time +@@ -8,11 +62,84 @@ import time import unittest from test import support @@ -147,7 +150,7 @@ index 97f951f1299..ce2c46777e0 100644 from math import isinf, isnan, copysign, ldexp import math -@@ -35,7 +159,7 @@ class FloatSubclass(float): +@@ -35,7 +162,7 @@ class FloatSubclass(float): class OtherFloatSubclass(float): pass @@ -156,7 +159,7 @@ index 97f951f1299..ce2c46777e0 100644 def test_float(self): self.assertEqual(float(3.14), 3.14) -@@ -620,7 +744,7 @@ class GeneralFloatCases(unittest.TestCase): +@@ -620,7 +747,7 @@ class GeneralFloatCases(unittest.TestCase): @unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__") @@ -165,7 +168,7 @@ index 97f951f1299..ce2c46777e0 100644 def test_getformat(self): self.assertIn(float.__getformat__('double'), ['unknown', 'IEEE, big-endian', 'IEEE, little-endian']) -@@ -645,7 +769,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) +@@ -645,7 +772,7 @@ LE_FLOAT_NAN = bytes(reversed(BE_FLOAT_NAN)) # is accident (today). # let's also try to guarantee that -0.0 and 0.0 don't get confused. @@ -174,7 +177,7 @@ index 97f951f1299..ce2c46777e0 100644 @support.requires_IEEE_754 def test_double_specials_do_unpack(self): -@@ -670,7 +794,7 @@ class IEEEFormatTestCase(unittest.TestCase): +@@ -670,7 +797,7 @@ class IEEEFormatTestCase(unittest.TestCase): self.assertEqual(struct.pack("=": "issuperset", -@@ -1334,22 +1402,22 @@ class TestSubsets: +@@ -1334,22 +1405,22 @@ class TestSubsets: result = eval("x" + case + "y", locals()) self.assertEqual(result, expected) # Test the "friendly" method-name spelling, if one exists. @@ -321,7 +324,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set() right = set() name = "both empty" -@@ -1357,7 +1425,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): +@@ -1357,7 +1428,7 @@ class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -330,7 +333,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1, 2]) right = set([1, 2]) name = "equal pair" -@@ -1365,7 +1433,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): +@@ -1365,7 +1436,7 @@ class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -339,7 +342,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set() right = set([1, 2]) name = "one empty, one non-empty" -@@ -1373,7 +1441,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): +@@ -1373,7 +1444,7 @@ class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -348,7 +351,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1]) right = set([1, 2]) name = "one a non-empty proper subset of other" -@@ -1381,7 +1449,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): +@@ -1381,7 +1452,7 @@ class TestSubsetPartial(TestSubsets, unittest.TestCase): #------------------------------------------------------------------------------ @@ -357,7 +360,7 @@ index d9102eb98a5..0b8e99a04c4 100644 left = set([1]) right = set([2]) name = "neither empty, neither contains" -@@ -1389,7 +1457,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): +@@ -1389,7 +1460,7 @@ class TestSubsetNonOverlap(TestSubsets, unittest.TestCase): #============================================================================== @@ -366,7 +369,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_eq_ne(self): # Unlike the others, this is testing that == and != *are* allowed. -@@ -1505,47 +1573,52 @@ class TestOnlySetsInBinaryOps: +@@ -1505,47 +1576,52 @@ class TestOnlySetsInBinaryOps: #------------------------------------------------------------------------------ @@ -425,7 +428,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def setUp(self): def gen(): for i in range(0, 10, 2): -@@ -1553,10 +1626,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): +@@ -1553,10 +1629,11 @@ class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase): self.set = set((1, 2, 3)) self.other = gen() self.otherIsIterable = True @@ -438,7 +441,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_copy(self): dup = self.set.copy() -@@ -1577,40 +1651,46 @@ class TestCopying: +@@ -1577,40 +1654,46 @@ class TestCopying: #------------------------------------------------------------------------------ @@ -491,7 +494,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_binopsVsSubsets(self): a, b = self.a, self.b -@@ -1727,7 +1807,7 @@ def L(seqn): +@@ -1727,7 +1810,7 @@ def L(seqn): 'Test multiple tiers of iterators' return chain(map(lambda x:x, R(Ig(G(seqn))))) @@ -500,7 +503,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_constructor(self): for cons in (set, frozenset): -@@ -1785,7 +1865,7 @@ class bad_dict_clear: +@@ -1785,7 +1868,7 @@ class bad_dict_clear: def __hash__(self): return 0 @@ -509,7 +512,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_8420_set_merge(self): # This used to segfault global be_bad, set2, dict2 -@@ -1826,7 +1906,7 @@ class TestWeirdBugs(unittest.TestCase): +@@ -1826,7 +1909,7 @@ class TestWeirdBugs(unittest.TestCase): s.update(other) @@ -518,7 +521,7 @@ index d9102eb98a5..0b8e99a04c4 100644 """Regression test for bpo-46615""" constructor1 = None -@@ -1862,7 +1942,7 @@ class TestOperationsMutating: +@@ -1862,7 +1945,7 @@ class TestOperationsMutating: self.assertIn("changed size during iteration", str(e)) @@ -527,7 +530,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_eq_with_mutation(self): self.check_set_op_does_not_crash(lambda a, b: a == b) -@@ -1933,24 +2013,24 @@ class TestBinaryOpsMutating(TestOperationsMutating): +@@ -1933,24 +2016,24 @@ class TestBinaryOpsMutating(TestOperationsMutating): self.check_set_op_does_not_crash(f3) @@ -557,7 +560,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_issubset_with_mutation(self): self.check_set_op_does_not_crash(set.issubset) -@@ -1986,27 +2066,27 @@ class TestMethodsMutating(TestOperationsMutating): +@@ -1986,27 +2069,27 @@ class TestMethodsMutating(TestOperationsMutating): self.check_set_op_does_not_crash(set.update) @@ -591,7 +594,7 @@ index d9102eb98a5..0b8e99a04c4 100644 constructor1 = set constructor2 = list -@@ -2068,7 +2148,7 @@ def faces(G): +@@ -2068,7 +2151,7 @@ def faces(G): return f @@ -600,7 +603,7 @@ index d9102eb98a5..0b8e99a04c4 100644 def test_cube(self): -@@ -2118,4 +2198,4 @@ class TestGraphs(unittest.TestCase): +@@ -2118,4 +2201,4 @@ class TestGraphs(unittest.TestCase): #============================================================================== if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_set.py b/test/dynamo/cpython/3_13/test_set.py index 0b8e99a04c45..3543d60751e3 100644 --- a/test/dynamo/cpython/3_13/test_set.py +++ b/test/dynamo/cpython/3_13/test_set.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_set.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_sort.diff b/test/dynamo/cpython/3_13/test_sort.diff index 78fde5ef19a1..9049f2853251 100644 --- a/test/dynamo/cpython/3_13/test_sort.diff +++ b/test/dynamo/cpython/3_13/test_sort.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py -index 2a7cfb7affa..d661ae544b9 100644 +index 2a7cfb7affa..58b9b796362 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 2a7cfb7affa..d661ae544b9 100644 from test import support import random import unittest -@@ -39,7 +90,7 @@ def check(tag, expected, raw, compare=None): +@@ -39,7 +93,7 @@ def check(tag, expected, raw, compare=None): nerrors += 1 return @@ -66,7 +69,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def testStressfully(self): # Try a variety of sizes at and around powers of 2, and at powers of 10. sizes = [0] -@@ -151,7 +202,7 @@ class TestBase(unittest.TestCase): +@@ -151,7 +205,7 @@ class TestBase(unittest.TestCase): self.assertEqual(forced, native) #============================================================================== @@ -75,7 +78,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_bug453523(self): # bug 453523 -- list.sort() crasher. -@@ -188,7 +239,7 @@ class TestBugs(unittest.TestCase): +@@ -188,7 +242,7 @@ class TestBugs(unittest.TestCase): #============================================================================== @@ -84,7 +87,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_decorated(self): data = 'The quick Brown fox Jumped over The lazy Dog'.split() -@@ -309,7 +360,7 @@ def check_against_PyObject_RichCompareBool(self, L): +@@ -309,7 +363,7 @@ def check_against_PyObject_RichCompareBool(self, L): self.assertIs(opt, ref) #note: not assertEqual! We want to ensure *identical* behavior. @@ -93,7 +96,7 @@ index 2a7cfb7affa..d661ae544b9 100644 def test_safe_object_compare(self): heterogeneous_lists = [[0, 'foo'], [0.0, 'foo'], -@@ -408,4 +459,4 @@ class TestOptimizedCompares(unittest.TestCase): +@@ -408,4 +462,4 @@ class TestOptimizedCompares(unittest.TestCase): #============================================================================== if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_sort.py b/test/dynamo/cpython/3_13/test_sort.py index d661ae544b99..58b9b7963622 100644 --- a/test/dynamo/cpython/3_13/test_sort.py +++ b/test/dynamo/cpython/3_13/test_sort.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sort.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_sys.diff b/test/dynamo/cpython/3_13/test_sys.diff index 7fd024156056..1c0cc65b3663 100644 --- a/test/dynamo/cpython/3_13/test_sys.diff +++ b/test/dynamo/cpython/3_13/test_sys.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py -index 72d51361e0b..0b4c6882e62 100644 +index 6b37094ed5f..c5e96a6a3dd 100644 --- a/test/dynamo/cpython/3_13/test_sys.py +++ b/test/dynamo/cpython/3_13/test_sys.py -@@ -1,3 +1,55 @@ +@@ -1,3 +1,58 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -58,7 +61,7 @@ index 72d51361e0b..0b4c6882e62 100644 import builtins import codecs import _datetime -@@ -35,7 +87,7 @@ def requires_subinterpreters(meth): +@@ -35,7 +90,7 @@ def requires_subinterpreters(meth): DICT_KEY_STRUCT_FORMAT = 'n2BI2n' @@ -67,7 +70,7 @@ index 72d51361e0b..0b4c6882e62 100644 def test_original_displayhook(self): dh = sys.__displayhook__ -@@ -81,19 +133,8 @@ class DisplayHookTest(unittest.TestCase): +@@ -81,19 +136,8 @@ class DisplayHookTest(unittest.TestCase): code = compile("42", "", "single") self.assertRaises(ValueError, eval, code) @@ -77,18 +80,18 @@ index 72d51361e0b..0b4c6882e62 100644 - sys.stdout = io.StringIO() - support.gc_collect() - return 'foo' -- + - with support.swap_attr(sys, 'stdout', None): - sys.stdout = io.StringIO() # the only reference - sys.displayhook(X()) # should not crash - +- - -class ActiveExceptionTests(unittest.TestCase): +class ActiveExceptionTests(__TestCase): def test_exc_info_no_exception(self): self.assertEqual(sys.exc_info(), (None, None, None)) -@@ -157,7 +198,7 @@ class ActiveExceptionTests(unittest.TestCase): +@@ -157,7 +201,7 @@ class ActiveExceptionTests(unittest.TestCase): self.assertIs(exc, e) @@ -97,7 +100,7 @@ index 72d51361e0b..0b4c6882e62 100644 @force_not_colorized def test_original_excepthook(self): -@@ -200,7 +241,7 @@ class ExceptHookTest(unittest.TestCase): +@@ -200,7 +244,7 @@ class ExceptHookTest(unittest.TestCase): # Python/pythonrun.c::PyErr_PrintEx() is tricky. @@ -106,7 +109,7 @@ index 72d51361e0b..0b4c6882e62 100644 def tearDown(self): test.support.reap_children() -@@ -500,6 +541,7 @@ class SysModuleTest(unittest.TestCase): +@@ -500,6 +544,7 @@ class SysModuleTest(unittest.TestCase): is sys._getframe().f_code ) @@ -114,16 +117,21 @@ index 72d51361e0b..0b4c6882e62 100644 def test_getframemodulename(self): # Default depth gets ourselves self.assertEqual(__name__, sys._getframemodulename()) -@@ -808,7 +850,7 @@ class SysModuleTest(unittest.TestCase): - self.assertRaises(TypeError, sys.intern, S("abc")) - if has_is_interned: - self.assertIs(sys._is_interned(S("abc")), False) -- -+ - @support.cpython_only - @requires_subinterpreters - def test_subinterp_intern_dynamically_allocated(self): -@@ -1359,7 +1401,7 @@ class SysModuleTest(unittest.TestCase): +@@ -894,7 +939,12 @@ class SysModuleTest(unittest.TestCase): + def assert_raise_on_new_sys_type(self, sys_attr): + # Users are intentionally prevented from creating new instances of + # sys.flags, sys.version_info, and sys.getwindowsversion. +- support.check_disallow_instantiation(self, type(sys_attr), sys_attr) ++ arg = sys_attr ++ attr_type = type(sys_attr) ++ with self.assertRaises(TypeError): ++ attr_type(arg) ++ with self.assertRaises(TypeError): ++ attr_type.__new__(attr_type, arg) + + def test_sys_flags_no_instantiation(self): + self.assert_raise_on_new_sys_type(sys.flags) +@@ -1354,7 +1404,7 @@ class SysModuleTest(unittest.TestCase): @test.support.cpython_only @@ -132,7 +140,7 @@ index 72d51361e0b..0b4c6882e62 100644 def test_original_unraisablehook(self): _testcapi = import_helper.import_module('_testcapi') from _testcapi import err_writeunraisable, err_formatunraisable -@@ -1516,7 +1558,7 @@ class UnraisableHookTest(unittest.TestCase): +@@ -1511,7 +1561,7 @@ class UnraisableHookTest(unittest.TestCase): @test.support.cpython_only @@ -141,7 +149,7 @@ index 72d51361e0b..0b4c6882e62 100644 def setUp(self): self.P = struct.calcsize('P') -@@ -1524,6 +1566,7 @@ class SizeofTest(unittest.TestCase): +@@ -1519,6 +1569,7 @@ class SizeofTest(unittest.TestCase): _testinternalcapi = import_helper.import_module("_testinternalcapi") self.gc_headsize = _testinternalcapi.SIZEOF_PYGC_HEAD self.managed_pre_header_size = _testinternalcapi.SIZEOF_MANAGED_PRE_HEADER @@ -149,7 +157,7 @@ index 72d51361e0b..0b4c6882e62 100644 check_sizeof = test.support.check_sizeof -@@ -1960,4 +2003,4 @@ class SizeofTest(unittest.TestCase): +@@ -1955,4 +2006,4 @@ class SizeofTest(unittest.TestCase): self.assertEqual(err, b"") if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_sys.py b/test/dynamo/cpython/3_13/test_sys.py index f2d782127a48..c5e96a6a3ddd 100644 --- a/test/dynamo/cpython/3_13/test_sys.py +++ b/test/dynamo/cpython/3_13/test_sys.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_sys.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_tuple.diff b/test/dynamo/cpython/3_13/test_tuple.diff index 46d4bb32d9ef..6e792b6c5450 100644 --- a/test/dynamo/cpython/3_13/test_tuple.diff +++ b/test/dynamo/cpython/3_13/test_tuple.diff @@ -1,8 +1,8 @@ diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py -index 9ce80c5e8ea..e52c0cbc140 100644 +index 9ce80c5e8ea..c6eab3ff1e9 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py -@@ -1,4 +1,55 @@ +@@ -1,4 +1,58 @@ -from test import support, seq_tests +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] @@ -10,6 +10,9 @@ index 9ce80c5e8ea..e52c0cbc140 100644 +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -59,7 +62,7 @@ index 9ce80c5e8ea..e52c0cbc140 100644 import unittest import gc -@@ -510,4 +561,4 @@ class TupleTest(seq_tests.CommonTest): +@@ -510,4 +564,4 @@ class TupleTest(seq_tests.CommonTest): # pileup 262,143 mean 8.0 coll 262,143 z +92683.6 if __name__ == "__main__": diff --git a/test/dynamo/cpython/3_13/test_tuple.py b/test/dynamo/cpython/3_13/test_tuple.py index e52c0cbc1403..c6eab3ff1e92 100644 --- a/test/dynamo/cpython/3_13/test_tuple.py +++ b/test/dynamo/cpython/3_13/test_tuple.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_tuple.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_userdict.diff b/test/dynamo/cpython/3_13/test_userdict.diff index 1c0157489206..8b8101ae9091 100644 --- a/test/dynamo/cpython/3_13/test_userdict.diff +++ b/test/dynamo/cpython/3_13/test_userdict.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py -index 61e79f553e8..c953390355e 100644 +index 61e79f553e8..75b789633ed 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py -@@ -1,3 +1,54 @@ +@@ -1,3 +1,57 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -57,7 +60,7 @@ index 61e79f553e8..c953390355e 100644 # Check every path through every method of UserDict from test import mapping_tests, support -@@ -215,10 +266,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): +@@ -215,10 +269,10 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): # Decorate existing test with recursion limit, because # the test is for C structure, but `UserDict` is a Python structure. diff --git a/test/dynamo/cpython/3_13/test_userdict.py b/test/dynamo/cpython/3_13/test_userdict.py index c953390355e6..75b789633edf 100644 --- a/test/dynamo/cpython/3_13/test_userdict.py +++ b/test/dynamo/cpython/3_13/test_userdict.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userdict.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/cpython/3_13/test_userlist.diff b/test/dynamo/cpython/3_13/test_userlist.diff index 299a8abeb99a..20999ba6bca0 100644 --- a/test/dynamo/cpython/3_13/test_userlist.diff +++ b/test/dynamo/cpython/3_13/test_userlist.diff @@ -1,14 +1,17 @@ diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py -index 312702c8e39..a4532922f5d 100644 +index 312702c8e39..5ede0c3b7f1 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py -@@ -1,7 +1,58 @@ +@@ -1,7 +1,61 @@ +# ======= BEGIN Dynamo patch ======= +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + ++# Test copied from ++# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py ++ +import sys +import torch +import torch._dynamo.test_case @@ -62,7 +65,7 @@ index 312702c8e39..a4532922f5d 100644 import unittest from test import support -@@ -69,9 +120,9 @@ class UserListTest(list_tests.CommonTest): +@@ -69,9 +123,9 @@ class UserListTest(list_tests.CommonTest): # Decorate existing test with recursion limit, because # the test is for C structure, but `UserList` is a Python structure. diff --git a/test/dynamo/cpython/3_13/test_userlist.py b/test/dynamo/cpython/3_13/test_userlist.py index a4532922f5d4..5ede0c3b7f1a 100644 --- a/test/dynamo/cpython/3_13/test_userlist.py +++ b/test/dynamo/cpython/3_13/test_userlist.py @@ -4,6 +4,9 @@ # ruff: noqa # flake8: noqa +# Test copied from +# https://raw.githubusercontent.com/python/cpython/refs/tags/v3.13.5/Lib/test/test_userlist.py + import sys import torch import torch._dynamo.test_case diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index af162b41ccd7..0de83cd2dc31 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1213,7 +1213,7 @@ def fn(x): @torch._functorch.config.patch(donated_buffer=True) def test_donated_buffer1(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" @torch.compile() def relu(x): @@ -1233,7 +1233,7 @@ def relu(x): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer2(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 @torch.compile() @@ -1255,7 +1255,7 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer3(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" # we will reuse the graph for g across f1 and f2 @torch.compile() @@ -1278,7 +1278,7 @@ def f(inp, param1, param2): @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer4(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" class Mod(torch.nn.Module): def __init__(self) -> None: @@ -1309,7 +1309,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch._functorch.config.patch("donated_buffer", True) def test_donated_buffer5(self): - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" @torch.compile() def f(x, z): @@ -1346,7 +1346,7 @@ def test_donated_buffer6(self): # SymNodeVariable() is not a constant return - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" def fn(x): p = torch.nn.Parameter(x + 123) diff --git a/test/dynamo/test_backends.py b/test/dynamo/test_backends.py index 7c4402edeca6..9d61bbf31acb 100644 --- a/test/dynamo/test_backends.py +++ b/test/dynamo/test_backends.py @@ -8,7 +8,6 @@ import torch._dynamo.backends import torch._dynamo.test_case from torch._dynamo.backends.debugging import ExplainWithBackend -from torch._dynamo.backends.onnxrt import has_onnxruntime from torch._dynamo.backends.tvm import has_tvm from torch._dynamo.testing import same from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module @@ -138,10 +137,6 @@ def test_aot_ts(self, device): def test_aot_cudagraphs(self, device): self._check_backend_works("cudagraphs", device) - @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") - def test_onnxrt(self, device): - self._check_backend_works("onnxrt", device) - @unittest.skipIf(not has_tvm(), "requires tvm") def test_tvm(self, device): self._check_backend_works("tvm", device) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 70e1946c3096..3b29e5e96119 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -514,6 +514,23 @@ def fn(x, s): fn(x, State(41)) self.assertEqual(cnts.frame_count, 2) + def test_nonstrict_trace_int_and_float_output(self): + @torch._dynamo.nonstrict_trace + def trace_me(x): + torch._dynamo.graph_break() + return len(x.shape), 0.42 + + def fn(x): + n1, n2 = trace_me(x) + return x * n1 + n2 + + x = torch.randn(10) + opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager") + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + def test_nonstrict_trace_tuple_and_sym_int_output(self): @torch._dynamo.nonstrict_trace def trace_me(x): @@ -719,6 +736,34 @@ def fn(x, y): except torch._dynamo.exc.Unsupported as e: self.assertIn("Invalid input type for nonstrict_trace-ed function", str(e)) + def test_nonstrict_trace_custom_class_output_error(self): + class Point: + x: torch.Tensor + y: torch.Tensor + + def __init__(self, x, y): + self.x = x + self.y = y + + @torch._dynamo.nonstrict_trace + def trace_me(x): + torch._dynamo.graph_break() + return Point(x, x + 1) + + @torch.compile(fullgraph=True, backend="aot_eager") + def fn(x): + p = trace_me(x) + return p.x * p.y + + try: + x = torch.ones(10) + fn(x) + self.assertFalse(True) # must raise error before this + except torch._dynamo.exc.Unsupported as e: + self.assertIn( + "Unsupported output type for nonstrict_trace-ed function", str(e) + ) + def test_nonstrict_newly_constructed_trace_register_constant_type_error(self): class State: def __init__(self, n): diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 74d17dd6825f..0164b6f9c680 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -50,6 +50,7 @@ def forward(self, x): return x +@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") class FxGraphRunnableTest(TestCase): def setUp(self): super().setUp() @@ -92,7 +93,6 @@ def _exec_and_verify_payload(self): ) # basic tests - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_basic_tensor_add(self): def f(x): return x + 1 @@ -100,7 +100,6 @@ def f(x): torch.compile(f)(torch.randn(4)) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_two_inputs_matmul(self): def f(a, b): return (a @ b).relu() @@ -109,7 +108,6 @@ def f(a, b): torch.compile(f)(a, b) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_scalar_multiply(self): def f(x): return x * 2 @@ -118,7 +116,6 @@ def f(x): self._exec_and_verify_payload() # testing dynamic shapes - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_dynamic_shapes_run(self): def f(x): return (x @ x.transpose(0, 1)).relu() @@ -130,7 +127,6 @@ def f(x): torch.compile(f)(a) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_broadcast_add_dynamic(self): def f(x, y): return x + y * 2 @@ -143,7 +139,6 @@ def f(x, y): torch.compile(f)(x, y) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_basic(self): model = ToyModel(input_size=8, hidden_size=16, output_size=4) model.eval() # Set to eval mode to avoid dropout randomness @@ -152,7 +147,6 @@ def test_toy_model_basic(self): torch.compile(model)(x) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_batch_processing(self): model = ToyModel(input_size=12, hidden_size=24, output_size=6) model.eval() @@ -161,7 +155,6 @@ def test_toy_model_batch_processing(self): torch.compile(model)(x) self._exec_and_verify_payload() - @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") def test_toy_model_dynamic_batch(self): model = ToyModel(input_size=10, hidden_size=20, output_size=5) model.eval() diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 1aeafaf5dd33..83d7cec8a7f1 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -811,7 +811,7 @@ def test_clone(self): except ImportError: from utils import install_guard_manager_testing_hook - def hook(guard_wrapper, f_locals): + def hook(guard_wrapper, f_locals, builder): root = guard_wrapper.root # Check full cloning works as expected @@ -851,7 +851,7 @@ def test_diff_guard_manager(self): from utils import install_guard_manager_testing_hook counter = 0 - def hook(guard_wrapper, f_locals): + def hook(guard_wrapper, f_locals, builder): nonlocal counter root = guard_wrapper.root diff_guard_root = guard_wrapper.diff_guard_root @@ -898,6 +898,65 @@ def fn(x, foo, bar): opt_fn(x, foo, bar) +class TypePropagationTests(torch._dynamo.test_case.TestCase): + @torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True) + def test_basic_types(self): + class Foo: + def __init__(self): + self.x = {"a": 2} + self.y = torch.randn(4) + self.z = {} + + foo = Foo() + + mod = torch.nn.Linear(4, 4) + + def fn(x): + return x + foo.x["a"] + foo.y + mod(x) + + try: + from .utils import install_guard_manager_testing_hook + except ImportError: + from utils import install_guard_manager_testing_hook + + def hook(guard_wrapper, f_locals, builder): + from torch._dynamo.source import AttrSource, DictGetItemSource, LocalSource + + foo_source = LocalSource("foo") + foo_x_source = AttrSource(foo_source, "x") + + self.assertTrue(builder.get(foo_source.name()) is foo) + self.assertTrue(builder.get(foo_x_source.name()) is foo.x) + + # Check types of foo.x + foo_x_mgr = builder.get_guard_manager_from_source(foo_x_source) + self.assertTrue(foo_x_mgr.is_guarded_value_dict()) + + # Check types of foo.x["a"] + foo_x_a_source = DictGetItemSource(foo_x_source, "a") + foo_x_a_mgr = builder.get_guard_manager_from_source(foo_x_a_source) + self.assertTrue(foo_x_a_mgr.is_guarded_value_immutable()) + + # Check types of foo.y + foo_y_source = AttrSource(foo_source, "y") + foo_y_mgr = builder.get_guard_manager_from_source(foo_y_source) + self.assertTrue(foo_y_mgr.is_guarded_value_immutable()) + + # Check types of foo.z + foo_z_source = AttrSource(foo_source, "z") + foo_z_mgr = builder.get_guard_manager_from_source(foo_z_source) + self.assertTrue(foo_z_mgr.is_guarded_value_empty_dict()) + + # Check types of mod + mod_source = LocalSource("mod") + mod_mgr = builder.get_guard_manager_from_source(mod_source) + self.assertTrue(mod_mgr.is_guarded_value_nn_module()) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + with install_guard_manager_testing_hook(hook): + opt_fn(torch.randn(4, 4)) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 8e5f12894711..10808c922b3f 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -878,6 +878,23 @@ def fn(x): ): self._test_serialization("ID_MATCH", fn, torch.randn(3)) + @torch._dynamo.config.patch(caching_precompile=True) + def test_id_match_with_config(self): + def fn(x): + return x + id(x) + + ref, loaded = self._test_serialization("ID_MATCH", fn, torch.randn(3)) + self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True) + + def fn(x): + # usage of this context manager installs a FUNCTION_MATCH guard + with torch.no_grad(): + y = x * 2 + return y + + ref, loaded = self._test_serialization("FUNCTION_MATCH", fn, torch.randn(3)) + self._test_check_fn(ref, loaded, {"x": torch.randn(3)}, True) + def test_dispatch_key_set_match(self): def fn(x, dks): if dks.has("CPU"): diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index e7b6d426247c..b9c1ff3a61fe 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1186,7 +1186,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): pred = a.sum() > 0 with self.assertRaisesRegex( NotImplementedError, - "no rule registered for HOP cond and mode .*MyMode", + "no rule registered for HigherOrderOperator cond and mode .*MyMode", ): with MyMode(): res = cond_op(pred, torch.sin, torch.cos, (a,)) @@ -7106,6 +7106,103 @@ def test_non_aliasing_util(self): ): _assert_tensors_nonaliasing(a, a) + def test_flop_counter_for_cond(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return torch.cond( + torch.tensor(True), + lambda x: self.linear(x), + lambda x: self.linear(self.linear(x)), + (x,), + ) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 256}, + "Mod": {torch.ops.aten.addmm: 256}, + "Mod.linear": {torch.ops.aten.addmm: 256}, + }, + ) + + def test_flop_counter_for_nested_cond(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(4, 4) + self.linear2 = torch.nn.Linear(4, 4) + + def forward(self, x): + def true_branch(x): + # Nested cond inside true branch + return torch.cond( + torch.tensor(True), + lambda x: self.linear1(x), + lambda x: self.linear2(x), + (x,), + ) + + def false_branch(x): + return self.linear1(self.linear2(x)) + + return torch.cond(torch.tensor(True), true_branch, false_branch, (x,)) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 256}, + "Mod": {torch.ops.aten.addmm: 256}, + "Mod.linear1": {torch.ops.aten.addmm: 128}, + "Mod.linear2": {torch.ops.aten.addmm: 128}, + }, + ) + + def test_flop_counter_for_cond_unbalanced_branches(self): + from torch.utils.flop_counter import FlopCounterMode + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + def true_branch(x): + return self.linear(x) + + def false_branch(x): + return x.clone() + + return torch.cond(torch.tensor(True), true_branch, false_branch, (x,)) + + mod = Mod() + with FlopCounterMode(mod, display=False) as mode: + mod(torch.randn(4, 4)) + + self.assertEqual( + mode.get_flop_counts(), + { + "Global": {torch.ops.aten.addmm: 128}, + "Mod": {torch.ops.aten.addmm: 128}, + "Mod.linear": {torch.ops.aten.addmm: 128}, + }, + ) + xfail_hops_compile = { # aot_eager diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 632ebdc39278..6d761305d0a8 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10422,6 +10422,26 @@ def fn(x, y): actual = fn_opt(*inps) expected = fn(*inps) + def test_nested_dataclass_reconstruct(self): + @dataclasses.dataclass(frozen=True) + class NestedDataClass: + x: int = 2 + + @dataclasses.dataclass(frozen=True) + class TestDataClass: + y: torch.Tensor + ndc: NestedDataClass = NestedDataClass() + + def fn(y): + dc = TestDataClass(y) + z = dc.y + dc.ndc.x + return z, dc + + fn_opt = torch.compile()(fn) + inps = (torch.ones(2, 2),) + actual = fn_opt(*inps) + expected = fn(*inps) + def test_frozen_dataclass_default_value(self): @dataclasses.dataclass(frozen=True) class TestDataClass: @@ -12962,6 +12982,38 @@ def f(actions, n_act, epsilon=0.1): y = torch.tensor(5) f(x, y) + def test_dynamic_float_scalar_tensor_coersion(self): + # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367 + class Foo: + def __init__(self): + self.config = type( + "Config", (), {"pad_val": 1123581321.0, "tolerance": 1e-6} + ) + + @torch.compile(fullgraph=True) + def forward(self, input): + outputs = torch.where( + torch.abs(input - self.config.pad_val) < self.config.tolerance, + torch.tensor( + self.config.pad_val, dtype=input.dtype, device=input.device + ), + torch.tensor( + self.config.pad_val + 1, dtype=input.dtype, device=input.device + ), + ) + return outputs + + foo = Foo() + inputs = torch.randn(3, 4) + result = foo.forward(inputs) + + original_pad_val = foo.config.pad_val + foo.config.pad_val += 1.0 + result2 = foo.forward(inputs) + + # Previously would crash with: + # RuntimeError: value cannot be converted to type at::Half without overflow + devices = ("cuda", "hpu", "xpu") instantiate_device_type_tests( diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index d75ea975cb74..d43a8d6c5564 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -15,10 +15,12 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache from torch._dynamo.precompile_context import PrecompileContext from torch._functorch import config as functorch_config +from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfRocm, ) from torch.testing._internal.inductor_utils import HAS_CUDA, HAS_XPU @@ -38,7 +40,9 @@ def setUp(self): DynamoCache.clear() PrecompileContext.clear() - def _save_and_reload(self, expected_backends, expected_dynamo): + def _save_and_reload( + self, expected_backends, expected_dynamo, expected_autotune=None + ): """ Serializes all artifacts, clears all caches, then reloads the serialized artifact Simulates a new process. @@ -54,6 +58,8 @@ def _save_and_reload(self, expected_backends, expected_dynamo): len(cache_info.precompile_aot_autograd_artifacts), expected_backends ) self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo) + if expected_autotune is not None: + self.assertEqual(len(cache_info.autotune_artifacts), expected_autotune) torch._dynamo.reset() DynamoCache.clear() @@ -377,7 +383,7 @@ def fn2(x): DynamoCache.save(package1) DynamoCache.save(package2) - + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=2) # These should exist because of populate_caches @@ -388,6 +394,7 @@ def fn2(x): result1 = compiled_fn1(arg1) result2 = compiled_fn2(arg2) self.assertEqual(expected, [result1, result2]) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -411,6 +418,7 @@ def fn2(x): result = [compiled_fn1(arg1), compiled_fn2(arg2)] self.assertEqual(expected, result) DynamoCache.clear() + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=2) @@ -420,6 +428,41 @@ def fn2(x): result1 = compiled_fn1(arg1) result2 = compiled_fn2(arg2) self.assertEqual(expected, [result1, result2]) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + + @parametrize("device", ("cuda", "xpu")) + @torch._dynamo.config.patch(caching_precompile=True) + @skipIfRocm + def test_automatic_dynamo_autotune_cache(self, device): + if device == "cuda" and not HAS_CUDA: + raise unittest.SkipTest("Requires CUDA/Triton") + if device == "xpu" and not HAS_XPU: + raise unittest.SkipTest("Requires XPU/Triton") + + def fn(x, y): + return x.sin() + y + + arg1 = torch.randn(3, 3, device=device) + arg2 = torch.randn(3, 3, device=device) + expected = fn(arg1, arg2).clone() + + with PatchCaches(): + compiled_fn1 = torch.compile(fn, mode="max-autotune") + result = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result) + self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1)) + DynamoCache.clear() + + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER + self._save_and_reload( + expected_backends=1, expected_dynamo=1, expected_autotune=1 + ) + compiled_fn1 = torch.compile(fn, mode="max-autotune") + with torch.compiler.set_stance("fail_on_recompile"): + result1 = compiled_fn1(arg1, arg2).clone() + self.assertEqual(expected, result1) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1)) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -439,6 +482,7 @@ def fn(x): # Should cause a recompile expected2 = compiled_fn(arg2) + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=2, expected_dynamo=1) @@ -451,6 +495,7 @@ def fn(x): compiled_fn(arg3) self.assertEqual(result1, expected1) self.assertEqual(result2, expected2) + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -486,6 +531,7 @@ def guard_filter_fn(guards): for args in args_list: compiled_fn(*args) + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=8, expected_dynamo=1) compiled_fn = torch._dynamo.optimize( @@ -494,6 +540,8 @@ def guard_filter_fn(guards): with torch.compiler.set_stance("fail_on_recompile"): for args in args_list: self.assertEqual(compiled_fn(*args), args[0].sum()) + # Should have same number of frames as on cold start + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) @parametrize("device", ("cpu", "cuda", "xpu")) @torch._dynamo.config.patch(caching_precompile=True) @@ -512,6 +560,7 @@ def fn(x): compiled_fn = torch.compile(fn) expected1 = compiled_fn(arg1) expected1.sum().backward() + total_frames = torch._dynamo.convert_frame.FRAME_COUNTER self._save_and_reload(expected_backends=1, expected_dynamo=1) @@ -521,6 +570,8 @@ def fn(x): expected2 = compiled_fn(arg2) expected2.sum().backward() + self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 9a7a892d8b02..860b337e95f7 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -192,6 +192,47 @@ def fn(x, y): ], ) + def test_profiler_enabled(self): + def fn(x): + x = torch.sin(x) + if torch.autograd._profiler_enabled(): + return torch.cos(x) + else: + return torch.sigmoid(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + with torch.autograd.profiler.profile(): + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_profiler_record_function_ignore(self): + def fn(x): + x = torch.sin(x) + if torch.autograd._profiler_enabled(): + with torch.autograd.profiler.record_function("dummy"): + return torch.cos(x) + else: + return torch.sigmoid(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + + with torch.autograd.profiler.profile(): + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 89202b9037e5..db1288fe5bf9 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -43,7 +43,13 @@ import torch.utils._pytree as pytree from torch import nn from torch._dynamo.debug_utils import same_two_models -from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 +from torch._dynamo.testing import ( + CompileCounter, + rand_strided, + same, + skipIfNotPy312, + skipIfPy312, +) from torch._inductor.utils import fresh_cache from torch.nn import functional as F from torch.profiler import profile, ProfilerActivity @@ -986,7 +992,7 @@ def tearDown(self) -> None: self.exit_stack.close() super().tearDown() - def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals): + def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder): root = guard_manager_wrapper.root cloned_root = root.clone_manager(lambda x: True) cloned_wrapper = torch._dynamo.guards.GuardManagerWrapper(cloned_root) @@ -2036,8 +2042,13 @@ def fn(x): ref0 = fn(x) ref1 = fn(x) - random.seed(0) opt_fn = torch.compile(fn, backend="eager") + # Especially for internal usage, there are many calls to random functions + # on first compile, e.g., from various library initializations. Run once + # to get that out of the way before resetting the seed: + opt_fn(x) + + random.seed(0) res0 = opt_fn(x) res1 = opt_fn(x) @@ -7067,6 +7078,50 @@ def f(x, out): torch.compile(f, backend="eager", fullgraph=True)(x, out_res) self.assertEqual(out_ref, out_res) + @skipIfNotPy312 + def test_sys_monitoring(self): + found_dynamo = False + found_compiled_graph = False + compiled_graph = None + + def backend(gm, _): + nonlocal compiled_graph + compiled_graph = gm + return gm + + def callback(code, offset): + nonlocal found_dynamo + nonlocal found_compiled_graph + torch._dynamo.graph_break() + if ( + code + is torch._dynamo.symbolic_convert.InstructionTranslator.run.__code__ + ): + found_dynamo = True + elif compiled_graph and code is compiled_graph.__call__.__code__: + found_compiled_graph = True + + sys.monitoring.use_tool_id(0, "test") + old_callback = sys.monitoring.register_callback( + 0, sys.monitoring.events.PY_START, callback + ) + sys.monitoring.set_events(0, sys.monitoring.events.PY_START) + try: + + @torch.compile(backend=backend, fullgraph=True) + def fn(x): + return x + 1 + + fn(torch.ones(3)) + # sys.monitoring should still run in Python dynamo + self.assertTrue(found_dynamo) + # sys.monitoring should still run on the compiled graph + self.assertTrue(found_compiled_graph) + finally: + sys.monitoring.register_callback( + 0, sys.monitoring.events.PY_START, old_callback + ) + def test_unbind_copy_out(self): def f(eye, out): torch.unbind_copy(eye, out=out) @@ -7569,6 +7624,81 @@ def f(x): with mock.patch("torch.cuda.is_initialized", lambda: False): self.assertEqual(f(inp), inp + 2) + def test_named_tuple_vt_clone(self): + # https://github.com/pytorch/pytorch/issues/157945 + class SVDCompressor(nn.Module): + def __init__(self, k=10): + super().__init__() + self.k = k + + def forward(self, x): + U, S = torch.linalg.svd(x)[:2] + reduced = U[:, :, : self.k] @ torch.diag_embed(S[:, : self.k]) + return reduced + + input = torch.randn(4, 8, 6) + model = SVDCompressor(k=5) + + out1 = model(input.clone()) + out2 = torch.compile(model, backend="eager")(input.clone()) + self.assertEqual(out1, out2) + + def test_filter_warnings(self): + x = torch.ones(2, 2, requires_grad=True) + + def call_foobar(x): + warnings.warn("foobar") + + @torch.compile(backend="eager") + def f(x): + call_foobar(x) + call_foobar(x) + call_foobar(x) + call_foobar(x) + return call_foobar(x) + + with warnings.catch_warnings(record=True) as w: + f(x) + self.assertEqual(len(w), 1) + self.assertEqual(str(w[0].message), "foobar") + + def test_filter_safe_grad_warning(self): + x = torch.ones(2, 2, requires_grad=True) + y = x * 5 # non-leaf, .grad should warn + torch._subclasses.meta_utils.safe_grad(y) # filters out warning + + def unsafe_grad(y): + return y.grad + + with warnings.catch_warnings(record=True) as w: + unsafe_grad(y) # should still warn, different callsite + self.assertEqual(len(w), 1) + self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message)) + + unsafe_grad(y) # should not warn + self.assertEqual(len(w), 1) + + def test_filter_user_warnings(self): + x = torch.ones(2, 2, requires_grad=True) + y = x * 5 # non-leaf, .grad should warn + + @torch._dynamo.eval_frame.TorchPatcher.suppress_torch_distributed_warnings + def mute_warn(y): + return y.grad + + mute_warn(y) # filters out warning + + def unsafe_grad(y): + return y.grad + + with warnings.catch_warnings(record=True) as w: + unsafe_grad(y) # should still warn, different callsite + self.assertEqual(len(w), 1) + self.assertTrue("The .grad attribute of a Tensor" in str(w[0].message)) + + unsafe_grad(y) # should not warn + self.assertEqual(len(w), 1) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 72fe71ace2da..cde880df17a6 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -247,6 +247,7 @@ def test_schedule(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -256,7 +257,6 @@ def test_schedule(self): {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -279,6 +279,7 @@ def test_cudagraphs(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -288,7 +289,6 @@ def test_cudagraphs(self): {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "triton_kernel_info", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} @@ -319,6 +319,7 @@ def fn(x, y): {"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -327,7 +328,6 @@ def fn(x, y): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "recompile_reasons", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -339,6 +339,7 @@ def fn(x, y): {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -347,7 +348,6 @@ def fn(x, y): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} """, # noqa: B950 @@ -369,6 +369,7 @@ def test_example_fn(self): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -377,7 +378,6 @@ def test_example_fn(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -424,6 +424,7 @@ def test_example_training_fn(self): {"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} @@ -434,7 +435,6 @@ def test_example_training_fn(self): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} @@ -506,6 +506,7 @@ def throw(x): {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -656,6 +657,7 @@ def forward(self, x): {"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -666,7 +668,6 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 16, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_tensor": {"id": 29, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1024, 1], "storage": 16, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 29, "source": "L['self']._modules['layers']._modules['1']._parameters['weight']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} @@ -675,6 +676,7 @@ def forward(self, x): {"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -685,7 +687,6 @@ def forward(self, x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -716,6 +717,7 @@ def fn(x): {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -724,7 +726,6 @@ def fn(x): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} """, # noqa: B950 @@ -875,6 +876,7 @@ def fn(a): {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -883,7 +885,6 @@ def fn(a): {"artifact": {"name": "after_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} @@ -1039,10 +1040,10 @@ def backward(ctx, gO): '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 9, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 1, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 7, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 10, "frame_compile_id": 1, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 11, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 9, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 12, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 13, "frame_compile_id": 1, "attempt": 0}', ] logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in expected)) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 8ae4c9e58343..70ba2a8bd1bd 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -132,6 +132,9 @@ def fn(shape): res1 = fn(shape) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(shape) random.seed(1) res2 = opt_fn(shape) @@ -151,6 +154,9 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -176,6 +182,9 @@ def fn(x): res1 = fn(x) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -206,6 +215,9 @@ def fn(x): random.seed(1) res1 = fn(x) opt_fn = torch.compile(fn, backend="eager") + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(x) random.seed(1) res2 = opt_fn(x) self.assertTrue(same(res1, res2)) @@ -232,6 +244,9 @@ def fn(x, rand2): random.seed(0) y_1, rand2_1, rand3_1 = fn(inp, random.Random(12)) state_1 = random.getstate() + # Especially for internal: before resetting the seed, first shake out any rng + # calls that occur on compile, e.g., as a result of some module initializations. + opt_fn(inp, random.Random(12)) random.seed(0) y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12)) state_2 = random.getstate() diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace b/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_whitespace deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_double_specials_do_unpack deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_float_specials_do_unpack deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding b/test/dynamo_expected_failures/CPython313-test_float-IEEEFormatTestCase.test_serialized_float_rounding deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_basic deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_big_range deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_for deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_class_iter deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_dict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_empty deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_for_loop deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_range deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_string deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_iter_tuple deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_dict deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_list deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_range deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_string deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple b/test/dynamo_expected_failures/CPython313-test_iter-TestCase.test_sinkstate_tuple deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances b/test/dynamo_expected_failures/CPython313-test_math-IsCloseTests.test_negative_tolerances deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcos deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAcosh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsin deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAsinh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtan2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testAtanh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCbrt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCopysign deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCos deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testCosh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testDegrees deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testExp2 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFabs deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFactorialHugeInputs deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFmod deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testFrexp deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLdexp deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog10 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testLog1p deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testModf deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testPow deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRadians deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSin deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSinh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testSqrt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTan deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh b/test/dynamo_expected_failures/CPython313-test_math-MathTests.testTanh deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_exceptions deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_input_exceptions deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_math_dist_leak deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter b/test/dynamo_expected_failures/CPython313-test_math-MathTests.test_nextafter deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex b/test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_from_hex rename to test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_nothing diff --git a/test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs b/test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_float-HexFloatTestCase.test_invalid_inputs rename to test/dynamo_expected_failures/TestAutogradFallback.test_no_autograd_kernel_inplace_mode_warn diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex128 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_complex64 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 b/test/dynamo_expected_failures/TestLinalgCPU.test_slogdet_errors_and_warnings_cpu_float64 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index ee9d466f6083..a590713ad0f8 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -75,7 +75,6 @@ aten::_ctc_loss.out aten::_ctc_loss_backward aten::_ctc_loss_backward.Tensor aten::_ctc_loss_backward.out -aten::_cudnn_attention_backward aten::_cudnn_attention_forward aten::_cudnn_ctc_loss aten::_cudnn_ctc_loss.Tensor @@ -375,7 +374,6 @@ aten::_fused_adamw_.tensor_lr aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional -aten::_fused_rms_norm aten::_fused_sdp_choice aten::_fused_sgd aten::_fused_sgd.out diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 953246be7a7b..d5611ad2d579 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -700,7 +700,7 @@ def forward(self, x: torch.Tensor): else: return self.w + self.m2(x) - # Super nested, parameters neeed to lifted + # Super nested, parameters need to be lifted # multiple times. class SuperNestedM(torch.nn.Module): def __init__(self) -> None: @@ -755,7 +755,7 @@ def forward(self, x: torch.Tensor): else: return self.linear(self.m2(x)) - # Super nested, parameters neeed to lifted + # Super nested, parameters need to be lifted # multiple times. class SuperNestedM1(torch.nn.Module): def __init__(self, dim: int) -> None: @@ -771,7 +771,7 @@ def forward(self, x: torch.Tensor): return self.linear(self.m2(x)) # Super nested, even the input needs to be - # lifted recursively due to value propogation optimiztaion. + # lifted recursively due to value propagation optimization. class SuperNestedM2(torch.nn.Module): def __init__(self, dim: int) -> None: super().__init__() diff --git a/test/export/test_export.py b/test/export/test_export.py index 2ded21ec87e0..497122e3cc75 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -36,6 +36,7 @@ from torch._higher_order_ops.associative_scan import associative_scan from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._higher_order_ops.scan import scan +from torch._higher_order_ops.while_loop import while_loop from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode from torch.export import ( @@ -249,6 +250,10 @@ def is_training_ir_test(test_name): ) +def is_training_ir_strict_test(test_name): + return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) + + def is_cpp_runtime_test(test_name): return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith( CPP_RUNTIME_NONSTRICT_SUFFIX @@ -927,7 +932,6 @@ def forward(self, x): ep = export(f, args, strict=False) self.assertEqual(ep.module()(*args), f(*args)) - @testing.expectedFailureCppSerDes # Cpp Ser/Der seems to fail parsing complicated guards def test_export_statically_known_true(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1583,6 +1587,39 @@ def forward(self): ) self.assertEqual(m(*args), ep.module()(*args)) + def test_cond_access_identical_symint_closure(self): + class Example2(torch.nn.Module): + def forward(self, x, trigger, target): + return torch.cond( + trigger == 1, + lambda: x + target, + lambda: x * target, + (), + ) + + m = Example2() + x = torch.randn(2) + trigger = 0 + target = 2 + args = (x, trigger, target) + ep = export(m, args, dynamic_shapes=(None, Dim.DYNAMIC, Dim.DYNAMIC)) + if is_training_ir_strict_test(self._testMethodName): + # In strict mode export's result capturing compiler, we create + # 2 new symints when re-fakifying the symint inputs. + # Then in run_decompositions, ep.range_constraints was updated + # where it checks the var_to_range and put the two newly added ones into the range_constraints. + self.assertExpectedInline( + str(tuple(ep.range_constraints.values())), + """(VR[0, int_oo], VR[0, int_oo], VR[-int_oo, int_oo], VR[-int_oo, int_oo])""", + ) + else: + self.assertExpectedInline( + str(tuple(ep.range_constraints.values())), + """(VR[0, int_oo], VR[0, int_oo])""", + ) + + self.assertEqual(m(*args), ep.module()(*args)) + def test_cond_branches_return_same_int(self): class M(torch.nn.Module): def forward(self, x): @@ -1773,6 +1810,36 @@ def forward(self, x): ): export(M(), (torch.randn(2, 3),), strict=False) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_while_loop_tensor_constant_idx(self): + def while_loop_decomp(x, y0): + out = torch.zeros_like(x) + + def cond_fn(idx, out, y0): + return idx < out.size(0) + + def body_fn(idx, out, y0): + i = idx.item() + torch._check_is_size(i, max=x.size(0) - 1) + y0 = x[i] + y0 + out = out.clone() + out[i] = y0 + return idx + 1, out, y0 + + cnt = torch.tensor(0) + _, out, _ = while_loop(cond_fn, body_fn, [cnt, out, y0]) + return out + + class TestModel(torch.nn.Module): + def forward(self, x, y0): + return while_loop_decomp(x, y0) + + x, y0 = torch.randn(16, 8), torch.randn(8) + exp_out = TestModel()(x, y0) + ep = export(TestModel(), (x, y0)) + out = ep.module()(x, y0) + self.assertEqual(exp_out, out) + def test_malformed_fqn_from_source_name(self): # See https://github.com/pytorch/pytorch/issues/141939 from types import MethodType @@ -1831,7 +1898,7 @@ def annotate_split_points(mod: torch.nn.Module, spec): for problem in [Problem1, Problem2]: m = problem() m(torch.rand(64, 64)) - # simpified torch.distributed.pipeline code + # simplified torch.distributed.pipeline code annotate_split_points(m, {"blocks.1": 1, "blocks.3": 1}) gm = export(m, (torch.rand(64, 64),)) torch.export.unflatten(gm) @@ -5011,7 +5078,6 @@ def forward(self, x): # There should be nonzero view nodes in the graph self.assertTrue(view_count > 0) - @testing.expectedFailureCppSerDes # cpp Ser/Der not handling complicated symbols def test_solver_unsupported_sympy_function(self): # repro of https://github.com/pytorch/pytorch/issues/131897 @@ -7560,6 +7626,69 @@ def forward(self, inputs): ]: self.assertFalse(hasattr(tensor, attr)) + @testing.expectedFailureCppRuntime + def test_while_loop_index_assertions(self): + from torch._higher_order_ops import while_loop + + class Foo(torch.nn.Module): + def forward(self, x): + def cond_fn(idx, acc): + i = idx.item() + return i < x.size(0) + + def body_fn(idx, acc): + # this check_is_size call needs to be traced by this subgraph for the select call, + # it can't be in the cond graph, as that fires & fails right before loop termination. + i = idx.item() + torch._check_is_size(i, max=x.size(0) - 1) + return idx + 1, acc + x[i] + + acc = torch.zeros(x.size(1)) + n = torch.full((), 0, dtype=torch.int64) + _, out = while_loop(cond_fn, body_fn, [n, acc]) + return out + + x = torch.randn(8, 4) + ep = export(Foo(), (x,), strict=False) + self.assertTrue(torch.allclose(x.sum(dim=0), ep.module()(x))) + + @testing.expectedFailureCppRuntime + def test_while_loop_assert_separation(self): + from torch._higher_order_ops import while_loop + + class Bar(torch.nn.Module): + def forward(self, idx, x): + i = idx.item() + + def cond_fn(idx, x): + i = idx.item() + torch._check(i != 5) + return i <= 9 + + def body_fn(idx, x): + i = idx.item() + torch._check(i % 2 == 0) + return idx + 2, x + i + + return while_loop(cond_fn, body_fn, [idx, x + i]) + + inps = (torch.tensor([0]), torch.zeros(1)) + ep = export(Bar(), inps, strict=False) + i, out = ep.module()(*inps) + self.assertEqual(i, 10) + self.assertEqual(out.item(), 20) + + # check assertions are separate for each subgraph + with self.assertRaisesRegex( + RuntimeError, r"Runtime assertion failed for expression Ne\(u[\d]+, 5\).*" + ): + ep.graph_module.while_loop_cond_graph_0(torch.tensor([5]), torch.zeros(1)) + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Eq\(PythonMod\(u[\d]+, 2\), 0\).*", + ): + ep.graph_module.while_loop_body_graph_0(torch.tensor([5]), torch.zeros(1)) + def test_constrain_decomp(self) -> None: class M(torch.nn.Module): def __init__(self) -> None: @@ -8056,7 +8185,7 @@ def false_fn(x): str(schema), """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""", ) - # serdes deserailizes tuple as list + # serdes deserializes tuple as list if need_serdes_test(self._testMethodName): self.assertExpectedInline( ep.graph_module.code.strip(), @@ -8742,10 +8871,6 @@ def forward(self, x): inp = torch.randn(2) self.assertTrue(torch.allclose(ep.module()(inp), torch.nonzero(inp))) - # TODO(pianpwk) blocker: https://github.com/pytorch/pytorch/issues/151809 - @testing.expectedFailureSerDer - @testing.expectedFailureSerDerNonStrict - @testing.expectedFailureCppSerDes def test_redundant_asserts(self): class Foo(torch.nn.Module): def forward(self, x): @@ -9192,7 +9317,7 @@ def forward(self, x): x = torch.rand(5, 2, 2) model = Model() - # Manualy set the fake_device of fake tensors. + # Manually set the fake_device of fake tensors. x.fake_device = torch.device("cuda:0") for n, p in model.named_parameters(): p.fake_device = torch.device("cuda:0") @@ -13495,9 +13620,6 @@ def forward(self, x, y): ): ep.module()(torch.randn(10), torch.tensor(2)) - @testing.expectedFailureCppSerDes # TODO: When we deserialize we somehow hardcode sympy.lower to 2 - @testing.expectedFailureSerDerNonStrict - @testing.expectedFailureSerDer @torch.fx.experimental._config.patch(backed_size_oblivious=True) def test_baddbmm(self): class M(torch.nn.Module): @@ -13522,7 +13644,7 @@ def forward(self, x): self.assertTrue(torch.allclose(m(x2), ep.module()(x2))) self.assertTrue(torch.allclose(m(x1), ep.module()(x1))) - @testing.expectedFailureSerDerNonStrict # construtor is not serialized today + @testing.expectedFailureSerDerNonStrict # constructor is not serialized today @testing.expectedFailureSerDer # constructor is not serialized today @testing.expectedFailureRetraceability # dynamo doesn't work with FlatApply op def test_capture_subclass_constructor(self): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 226404737e26..d174405dd8e0 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -1816,6 +1816,60 @@ def forward(self, x): self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") self.assertEqual(counter, 1) + def test_unbacked_range_serdes(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + n = x.item() + torch._check_is_size(n, max=y.size(0) - 1) + return torch.empty(n), y[n] + + ep = torch.export.export( + Foo(), + (torch.tensor([5]), torch.randn(10)), + dynamic_shapes={ + "x": None, + "y": (Dim.DYNAMIC,), + }, + ) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + + # pre-serialize ep + pre_shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in ep.graph.nodes] + ).shape_env + post_shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in loaded_ep.graph.nodes] + ).shape_env + self.assertEqual(pre_shape_env.var_to_range, post_shape_env.var_to_range) + + def test_backed_size_oblivious_serdes(self): + class Foo(torch.nn.Module): + def forward(self, x, y, z): + return x + y + z.item() + + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + ep = torch.export.export( + Foo(), + (torch.randn(1), torch.randn(1), torch.tensor([5])), + dynamic_shapes={ + "x": (Dim.DYNAMIC,), + "y": (Dim.DYNAMIC,), + "z": None, + }, + ) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + shape_env = torch._guards.detect_fake_mode( + [node.meta.get("val") for node in loaded_ep.graph.nodes] + ).shape_env + s0 = next(iter(ep.graph.nodes)).meta["val"].size(0) + self.assertEqual(shape_env.var_to_range[s0.node.expr].lower, 0) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 3f8f11aca0e5..214f3ce2fdfa 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1274,7 +1274,7 @@ def forward(self, tq, x): self.assertEqual(cnt.frame_count, 1) tq2 = _empty_tensor_queue() - # make first tensor's secon dim dynamic + # make first tensor's second dim dynamic tq2.push(torch.randn(2, 4, requires_grad=False)) torch.compile(mod, backend=cnt)(tq2, x) self.assertEqual(cnt.frame_count, 2) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index d6cf2df4343f..5a962dfa57c0 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -139,6 +139,8 @@ # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp # TODO: add back restriction when c10d ops can be exported ("c10d::.*", datetime.date(9999, 1, 1)), + # Previously MPS_only did not support backward + ("aten::_fused_rms_norm", datetime.date(2025, 12, 30)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/functorch/attn_ft.py b/test/functorch/attn_ft.py index ee4656631964..7038ded09490 100644 --- a/test/functorch/attn_ft.py +++ b/test/functorch/attn_ft.py @@ -126,7 +126,7 @@ def forward( if self.position_embedding_type == "relative_key": # these were einsum ops in the positional code because they are not easy to fit to existing matmul operators - # eventhough they are degenerate matmuls + # even though they are degenerate matmuls relative_position_scores = (q * positional_embedding).sum(features) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index f1d1c92d52f4..869fc6964f2f 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2372,7 +2372,7 @@ def f(a, b): return a.mul(3), b.mul(4) inp = [ - # First inp doesnt require grad, but we switch it on + # First inp doesn't require grad, but we switch it on torch.ones(3, 3, requires_grad=False), torch.ones(3, 3, requires_grad=True), ] @@ -5670,7 +5670,7 @@ def f(a, b, c, d): _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, - # then 4 tensors (transposes of matricies used for mm) are saved + # then 4 tensors (transposes of matrices used for mm) are saved # finally 3 symints are saved [False, True, True, False, False] + [False] * 4 + [True] * 3, [is_sym_node(n) for n in fw_graph_out_nodes], @@ -6000,7 +6000,7 @@ def f(a, b): self.assertEqual(b_test.a, b_ref.a) self.assertEqual(b_test.b, b_ref.b) - # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile teh backward. + # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. (b_ref * out_ref).sum().backward() (b_test * out_test).sum().backward() # Both grad_inputs are TwoTensor @@ -7469,7 +7469,7 @@ def test_saved_tensors_hooks_donated_buffers(self): "pack_hash", "unpack_hash", ) - logger_name = "torch._functorch._aot_autograd.jit_compile_runtime_wrappers" + logger_name = "torch._functorch._aot_autograd.graph_compile" class SAF(torch.autograd.Function): @staticmethod diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 1508997384d2..54ccd0f7fef2 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1409,7 +1409,6 @@ def f(x, y): f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5) ) - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_map_illegal_outputs(self): def f(x, y): return x.item() @@ -2013,7 +2012,7 @@ def test_scan_complex_pytree(self, reverse, compile_mode, device, autograd): if autograd: self.check_autograd(result, expected_result, (init, inp)) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # The paT206899919 rameterization is commented out for the moment and the test is marked with expected fail # Fails with: AssertionError: scan is not an OpOverload @skipIfRocm(msg="Unsupported on ROCM yet") @@ -4143,7 +4142,7 @@ def second_chain_fct(scan_fct, inp, **kwargs): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4241,7 +4240,7 @@ def body_fn(ind, loop_val): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -4314,7 +4313,7 @@ def combine_fn(x, y): inputs=inp, ) - # TODO: Does not work because of the usage of vmap witin associative_scan + # TODO: Does not work because of the usage of vmap within associative_scan # TODO: Re-enable additional parameters again once this issues has been resolved @unittest.skipIf(not SM70OrLater, "triton") @requires_cuda @@ -5315,7 +5314,7 @@ def forward(self, arg0_1): ) @parametrize("func_type", ["no", "cpp", "python", "functorch"]) - # - "simple_with_linear" and "nested_with_linear" doesn't work becaue parameters and buffers + # - "simple_with_linear" and "nested_with_linear" doesn't work because parameters and buffers # are not inputs so they're not wrapped by functionalization and tracing. # # - make_fx tracing mode "real" fails for "int_carry", "pytree_int_carry" and "const_and_symint_output" @@ -5392,18 +5391,18 @@ def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_de self.assertExpectedInline( gm.cond_fn_0.code.strip(), """\ -def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): - sub = l_iter_ - l_self_buffers_dec__cond_fn; l_iter_ = l_self_buffers_dec__cond_fn = None +def forward(self, child : torch.Tensor, child_1 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + sub = child - l_self_buffers_dec__cond_fn; child = l_self_buffers_dec__cond_fn = None gt = sub > 0; sub = None return gt""", # noqa: B950 ) self.assertExpectedInline( gm.body_fn_0.code.strip(), """\ -def forward(self, l_iter_ : torch.Tensor, l_x_ : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): - child = l_iter_ - 1; l_iter_ = None - child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None - return (child, child_1)""", # noqa: B950 +def forward(self, child_2 : torch.Tensor, child_3 : torch.Tensor, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn): + child = child_2 - 1; child_2 = None + child_4 = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None + return (child, child_4)""", # noqa: B950 ) else: self.assertExpectedInline( @@ -7669,12 +7668,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ self.assertEqual(compiled_out, exp_out) @skipIfTorchDynamo("Skip because we're testing export") - # TODO: we cannot turn on strict=True yet because torch._check for out_it > 0 is - # removed from the graph in dynamo and in non-strict export's graph capturing - # step, we re-run the traced graph module to get graph captured result. - # Since torch._check is removed from graph, we end up getting a data-dependent - # error when we call torch.ones(out_it * 2). - @parametrize("strict", [False]) + @parametrize("strict", [True, False]) @parametrize("dynamic", [True, False]) def test_while_loop_op_int_carry_export(self, strict, dynamic): m, args = WHILE_LOOP_TESTS["int_carry"] @@ -7717,8 +7711,9 @@ def forward(self, x): class while_loop_cond_graph_0(torch.nn.Module): def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): - sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None - lt: "Sym(u0 < s77)" = it_1 < sym_size_int; it_1 = sym_size_int = None + sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None + + lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None return lt class while_loop_body_graph_0(torch.nn.Module): @@ -7758,62 +7753,62 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (0, l_x_), (s27, s77)); cond_fn_0 = body_fn_0 = l_x_ = s27 = None - getitem_4: "Sym(u1)" = while_loop[0] + getitem_4: "Sym(u2)" = while_loop[0] - ge: "Sym(u1 >= 1)" = getitem_4 >= 1 - _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u1 >= 1 on node 'ge'"); ge = _assert_scalar_default = None + ge: "Sym(u2 >= 1)" = getitem_4 >= 1 + _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u2 >= 1 on node 'ge'"); ge = _assert_scalar_default = None - gt_1: "Sym(u1 > 0)" = getitem_4 > 0 - _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None + gt_1: "Sym(u2 > 0)" = getitem_4 > 0 + _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u2 on node 'gt_1'"); gt_1 = _assert_scalar_default_1 = None out_x: "f32[s77, s27]" = while_loop[1]; while_loop = None - gt: "Sym(u1 > 0)" = getitem_4 > 0 + gt: "Sym(u2 > 0)" = getitem_4 > 0 _check = torch._check(gt); gt = _check = None - add: "Sym(u1 + 1)" = getitem_4 + 1 + add: "Sym(u2 + 1)" = getitem_4 + 1 add_1: "f32[s77, s27]" = getitem_4 + out_x; out_x = None - lt: "Sym(u1 < s77)" = getitem_4 < s77; s77 = None + lt: "Sym(u2 < s77)" = getitem_4 < s77; s77 = None - mul: "Sym(2*u1)" = getitem_4 * 2; getitem_4 = None - ones: "f32[2*u1]" = torch.ones(mul); mul = None + mul: "Sym(2*u2)" = getitem_4 * 2; getitem_4 = None + ones: "f32[2*u2]" = torch.ones(mul); mul = None return (add, add_1, lt, ones) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint: "Sym(u0)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - size = l_x_.size(); l_x_ = None + size = child.size(); child = None getitem: "Sym(s77)" = size[0] getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None lt: "Sym(u0 < s77)" = unbacked_symint < getitem; unbacked_symint = getitem = None return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", l_x_: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - x_clone: "f32[s77, s27]" = l_x_.clone() + x_clone: "f32[s77, s27]" = child_1.clone() - ge: "Sym(u0 >= 0)" = unbacked_symint >= 0 + ge: "Sym(u1 >= 0)" = unbacked_symint_0 >= 0 _check = torch._check(ge); ge = _check = None - size = l_x_.size(); l_x_ = None + size = child_1.size(); child_1 = None getitem: "Sym(s77)" = size[0] getitem_1: "Sym(s27)" = size[1]; size = getitem_1 = None - lt: "Sym(u0 < s77)" = unbacked_symint < getitem; getitem = None + lt: "Sym(u1 < s77)" = unbacked_symint_0 < getitem; getitem = None _check_1 = torch._check(lt); lt = _check_1 = None - select: "f32[s27]" = x_clone.select(0, unbacked_symint) - select_1: "f32[s27]" = x_clone.select(0, unbacked_symint) - add: "f32[s27]" = select_1 + unbacked_symint; select_1 = None + select: "f32[s27]" = x_clone.select(0, unbacked_symint_0) + select_1: "f32[s27]" = x_clone.select(0, unbacked_symint_0) + add: "f32[s27]" = select_1 + unbacked_symint_0; select_1 = None copy_: "f32[s27]" = select.copy_(add); select = add = copy_ = None - add_1: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None + add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None return (add_1, x_clone) """, # noqa: B950 ) @@ -7917,30 +7912,30 @@ def forward(self, L_t_: "f32[2, 3]"): sum_1: "f32[]" = l_t_.sum() to: "i64[]" = sum_1.to(torch.int64); sum_1 = None item: "Sym(u0)" = to.item(); to = None - child: "f32[2, 3]" = l_t_.sin() + sin: "f32[2, 3]" = l_t_.sin() cond_fn_0 = self.cond_fn_0 body_fn_0 = self.body_fn_0 - while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, child), ()); cond_fn_0 = body_fn_0 = item = child = None - - getitem_8: "Sym(u8)" = while_loop[0] - getitem_9: "Sym(u9)" = while_loop[1] - getitem_10: "Sym(u10)" = while_loop[2] - getitem_11: "Sym(u11)" = while_loop[3] - getitem_12: "Sym(u12)" = while_loop[4] - getitem_13: "Sym(u13)" = while_loop[5] - getitem_14: "Sym(u14)" = while_loop[6] - - child_1: "f32[2, 3]" = while_loop[7]; while_loop = None - - add: "Sym(u8 + 1)" = getitem_8 + 1 - add_1: "Sym(u9 + 1)" = getitem_9 + 1 - add_2: "Sym(u10 + 1)" = getitem_10 + 1 - add_3: "Sym(u11 + 1)" = getitem_11 + 1 - add_4: "Sym(u12 + 1)" = getitem_12 + 1 - add_5: "Sym(u13 + 1)" = getitem_13 + 1 - add_6: "Sym(u14 + 1)" = getitem_14 + 1 - add_7: "f32[2, 3]" = child_1 + 1 + while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (2, 3, 1, 1, 1, 3, item, sin), ()); cond_fn_0 = body_fn_0 = item = sin = None + + getitem_8: "Sym(u15)" = while_loop[0] + getitem_9: "Sym(u16)" = while_loop[1] + getitem_10: "Sym(u17)" = while_loop[2] + getitem_11: "Sym(u18)" = while_loop[3] + getitem_12: "Sym(u19)" = while_loop[4] + getitem_13: "Sym(u20)" = while_loop[5] + getitem_14: "Sym(u21)" = while_loop[6] + + child: "f32[2, 3]" = while_loop[7]; while_loop = None + + add: "Sym(u15 + 1)" = getitem_8 + 1 + add_1: "Sym(u16 + 1)" = getitem_9 + 1 + add_2: "Sym(u17 + 1)" = getitem_10 + 1 + add_3: "Sym(u18 + 1)" = getitem_11 + 1 + add_4: "Sym(u19 + 1)" = getitem_12 + 1 + add_5: "Sym(u20 + 1)" = getitem_13 + 1 + add_6: "Sym(u21 + 1)" = getitem_14 + 1 + add_7: "f32[2, 3]" = child + 1 add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None @@ -7949,7 +7944,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None - add_15: "f32[2, 3]" = child_1 + l_t_; child_1 = l_t_ = None + add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15) class cond_fn_0(torch.nn.Module): @@ -7961,10 +7956,10 @@ def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unba return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u1)", unbacked_symint_0: "Sym(u2)", unbacked_symint_1: "Sym(u3)", unbacked_symint_2: "Sym(u4)", unbacked_symint_3: "Sym(u5)", unbacked_symint_4: "Sym(u6)", unbacked_symint_5: "Sym(u7)", child: "f32[2, 3]"): - add: "Sym(u7 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None - child_1: "f32[2, 3]" = child + 1; child = None - return (unbacked_symint_0, unbacked_symint_1, unbacked_symint_2, unbacked_symint_3, unbacked_symint, 0, add, child_1) + def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", unbacked_symint_8: "Sym(u10)", unbacked_symint_9: "Sym(u11)", unbacked_symint_10: "Sym(u12)", unbacked_symint_11: "Sym(u13)", unbacked_symint_12: "Sym(u14)", child_1: "f32[2, 3]"): + add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None + child: "f32[2, 3]" = child_1 + 1; child_1 = None + return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child) """, # noqa: B950 ) @@ -7992,17 +7987,17 @@ def forward(self, x): while_loop_body_graph_0 = self.while_loop_body_graph_0 while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 2, 2, 3, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = sin = None - getitem_6: "Sym(u5)" = while_loop[0] - getitem_7: "Sym(u6)" = while_loop[1] - getitem_8: "Sym(u7)" = while_loop[2] - getitem_9: "Sym(u8)" = while_loop[3] - getitem_10: "Sym(u9)" = while_loop[4] + getitem_6: "Sym(u10)" = while_loop[0] + getitem_7: "Sym(u11)" = while_loop[1] + getitem_8: "Sym(u12)" = while_loop[2] + getitem_9: "Sym(u13)" = while_loop[3] + getitem_10: "Sym(u14)" = while_loop[4] getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None - add: "Sym(u7 + 1)" = getitem_8 + 1 - add_1: "Sym(u8 + 1)" = getitem_9 + 1 - add_2: "Sym(u9 + 1)" = getitem_10 + 1 + add: "Sym(u12 + 1)" = getitem_8 + 1 + add_1: "Sym(u13 + 1)" = getitem_9 + 1 + add_2: "Sym(u14 + 1)" = getitem_10 + 1 add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None @@ -8010,21 +8005,21 @@ def forward(self, x): return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec) class while_loop_cond_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): - mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None - mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None - mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None - lt: "Sym(u17*u18*u19 < u15*u16)" = mul_1 < mul_2; mul_1 = mul_2 = None + def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"): + mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None + mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None + mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None + lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None return lt class while_loop_body_graph_0(torch.nn.Module): - def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s77, 3]"): - add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None - add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None + def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"): + add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None + add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None - add_2: "Sym(u17 + 1)" = arg2_1 + 1; arg2_1 = None - add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None - add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None + add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None + add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None + add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None return (add, add_1, add_2, add_3, add_4, add_5) @@ -8034,7 +8029,6 @@ def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", ar @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @parametrize("dynamic", [True, False]) @parametrize("backend", ["eager", "aot_eager"]) - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_while_loop_op_pytree_int_carry_compile(self, dynamic, backend): m, args = WHILE_LOOP_TESTS["pytree_int_carry"] if backend == "eager": @@ -8059,17 +8053,17 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): body_fn_0 = self.body_fn_0 while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (s77, s27, 2, 2, 3, child), (s27, s77)); cond_fn_0 = body_fn_0 = s77 = s27 = child = None - getitem_10: "Sym(u5)" = while_loop[0] - getitem_11: "Sym(u6)" = while_loop[1] - getitem_12: "Sym(u7)" = while_loop[2] - getitem_13: "Sym(u8)" = while_loop[3] - getitem_14: "Sym(u9)" = while_loop[4] + getitem_10: "Sym(u10)" = while_loop[0] + getitem_11: "Sym(u11)" = while_loop[1] + getitem_12: "Sym(u12)" = while_loop[2] + getitem_13: "Sym(u13)" = while_loop[3] + getitem_14: "Sym(u14)" = while_loop[4] out_x: "f32[s77, s27]" = while_loop[5]; while_loop = None - add: "Sym(u7 + 1)" = getitem_12 + 1 - add_1: "Sym(u8 + 1)" = getitem_13 + 1 - add_2: "Sym(u9 + 1)" = getitem_14 + 1 + add: "Sym(u12 + 1)" = getitem_12 + 1 + add_1: "Sym(u13 + 1)" = getitem_13 + 1 + add_2: "Sym(u14 + 1)" = getitem_14 + 1 add_3: "f32[s77, s27]" = getitem_12 + out_x; getitem_12 = None add_4: "f32[s77, s27]" = getitem_13 + out_x; getitem_13 = None @@ -8077,7 +8071,7 @@ def forward(self, s77: "Sym(s77)", s27: "Sym(s27)", L_x_: "f32[s77, s27]"): return (getitem_10, getitem_11, add, add_1, add_2, add_3, add_4, add_5, out_x) class cond_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child_1: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 @@ -8088,19 +8082,19 @@ def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unba return lt class body_fn_0(torch.nn.Module): - def forward(self, unbacked_symint: "Sym(u0)", unbacked_symint_0: "Sym(u1)", unbacked_symint_1: "Sym(u2)", unbacked_symint_2: "Sym(u3)", unbacked_symint_3: "Sym(u4)", child: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): + def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", unbacked_symint_6: "Sym(u7)", unbacked_symint_7: "Sym(u8)", unbacked_symint_8: "Sym(u9)", child_2: "f32[s77, s27]", s27: "Sym(s27)", s77: "Sym(s77)"): s27_1 = s27 s77_1 = s77 - add: "Sym(u0 + 1)" = unbacked_symint + 1; unbacked_symint = None - add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None + add: "Sym(u5 + 1)" = unbacked_symint_4 + 1; unbacked_symint_4 = None + add_1: "Sym(u6 + 1)" = unbacked_symint_5 + 1; unbacked_symint_5 = None - add_2: "Sym(u2 + 1)" = unbacked_symint_1 + 1; unbacked_symint_1 = None - add_3: "Sym(u3 + 1)" = unbacked_symint_2 + 1; unbacked_symint_2 = None - add_4: "Sym(u4 + 1)" = unbacked_symint_3 + 1; unbacked_symint_3 = None + add_2: "Sym(u7 + 1)" = unbacked_symint_6 + 1; unbacked_symint_6 = None + add_3: "Sym(u8 + 1)" = unbacked_symint_7 + 1; unbacked_symint_7 = None + add_4: "Sym(u9 + 1)" = unbacked_symint_8 + 1; unbacked_symint_8 = None - child_1: "f32[s77, s27]" = child + 1; child = None - return (add, add_1, add_2, add_3, add_4, child_1) + child: "f32[s77, s27]" = child_2 + 1; child_2 = None + return (add, add_1, add_2, add_3, add_4, child) """, # noqa: B950 ) @@ -8196,7 +8190,6 @@ def _check_export_ret_graph_str(self, fn, args, dynamic_shapes=None) -> str: return normalize_gm(non_strict_ep.module().print_readable(print_output=False)) @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") - @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_cond_eager_run_with_item(self): class M(torch.nn.Module): def forward(self, a, b1, b2, c): diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6ba61a6c1d0d..0f893201733d 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4152,7 +4152,7 @@ def test(): with subtest_ctx(self), skip_xfail_ctx(self): args = (sample_input.input,) + sample_input.args if not any(isinstance(arg, torch.Tensor) for arg in args): - # Atleast one tensor required for vmap. + # At least one tensor required for vmap. continue kwargs = sample_input.kwargs is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs) @@ -4230,7 +4230,7 @@ def sample_vmap_out_dim_numpy_split_copy_with_int( xfail("as_strided_copy"), xfail( "as_strided_scatter" - ), # no batching rule implemented, default doesnt work + ), # no batching rule implemented, default doesn't work skip( "new_empty_strided" ), # empty tensor data is garbage so it's hard to make comparisons with it diff --git a/test/fx/test_fx_traceback.py b/test/fx/test_fx_traceback.py index 6306daa571bd..05369d17078b 100644 --- a/test/fx/test_fx_traceback.py +++ b/test/fx/test_fx_traceback.py @@ -2,6 +2,7 @@ import torch from torch._inductor.compile_fx import aot_export_module +from torch.export import default_decompositions from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction from torch.testing._internal.common_utils import TestCase @@ -31,6 +32,8 @@ def test_node_source(self): dummy_source_dict, ) + self.assertEqual(node_source, NodeSource._from_dict(node_source.to_dict())) + # Dummy node node = torch.fx.Node( graph=torch.fx.Graph(), @@ -64,6 +67,62 @@ def test_node_source(self): }, ) + # Test two node sources are same + node_source1 = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + node_source2 = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + self.assertEqual(node_source1, node_source2) + + # Test hash function - equivalent objects should have same hash + self.assertEqual(hash(node_source1), hash(node_source2)) + + # Test two node sources are not same + node_source3 = NodeSource( + node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE + ) + node_source4 = NodeSource( + node=None, pass_name="test_pass_2", action=NodeSourceAction.CREATE + ) + self.assertNotEqual(node_source3, node_source4) + + # Test hash function - different objects should have different hash + self.assertNotEqual(hash(node_source3), hash(node_source4)) + + # Test that equivalent NodeSource objects can be used in sets and dicts + node_set = {node_source1, node_source2} + self.assertEqual(len(node_set), 1) # Should only contain one unique element + + node_dict = {node_source1: "value1", node_source2: "value2"} + self.assertEqual(len(node_dict), 1) # Should only contain one key + self.assertEqual(node_dict[node_source1], "value2") # Last value should win + + # Test with more complex NodeSource objects + node_source_with_node = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + node_source_with_node_copy = NodeSource( + node=node, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + + # These should be equal and have same hash + self.assertEqual(node_source_with_node, node_source_with_node_copy) + self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy)) + + # Test with different actions + node_source_replace = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE + ) + node_source_create = NodeSource( + node=None, pass_name="test_pass", action=NodeSourceAction.CREATE + ) + + # These should be different and have different hashes + self.assertNotEqual(node_source_replace, node_source_create) + self.assertNotEqual(hash(node_source_replace), hash(node_source_create)) + def test_graph_provenance(self): def check_node_source(node_source_dict, name, pass_name, action): self.assertEqual(node_source_dict["name"], name) @@ -95,6 +154,57 @@ def forward(self, x): model = Model() example_inputs = (torch.randn(8, 10),) ep = torch.export.export(model, example_inputs, strict=True) + + decomposed_ep = ep.run_decompositions(default_decompositions()) + # node decomposed from same ancestor node should have same from_node info + for node in decomposed_ep.graph.nodes: + if node.op not in {"placeholder", "output"}: + assert "from_node" in node.meta + + node_name_to_from_node = { + node.name: node.meta["from_node"] + for node in decomposed_ep.graph.nodes + if node.op not in {"placeholder", "output"} + } + same_ancestor_nodes = { + "permute": "addmm", + "addmm": "permute", + "permute_1": "addmm_1", + "addmm_1": "permute_1", + } + + for node_name_1 in node_name_to_from_node: + for node_name_2 in node_name_to_from_node: + if node_name_2 in { + node_name_1, + same_ancestor_nodes[node_name_1] + if node_name_1 in same_ancestor_nodes + else None, + }: + self.assertEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], + ) + else: + self.assertNotEqual( + node_name_to_from_node[node_name_1], + node_name_to_from_node[node_name_2], + ) + self.assertNotEqual( + [ + NodeSource._from_dict(ns.to_dict()) + for ns in node_name_to_from_node[node_name_1] + ], + node_name_to_from_node[node_name_2], + ) + gm = ep.module() provenance = get_graph_provenance_json(gm.graph) self.assertEqual( diff --git a/test/fx/test_fx_xform_observer.py b/test/fx/test_fx_xform_observer.py index 2517439d9fe3..10577712196b 100644 --- a/test/fx/test_fx_xform_observer.py +++ b/test/fx/test_fx_xform_observer.py @@ -55,7 +55,7 @@ def replacement(x): ) ) - @torch._inductor.config.patch("trace.enabled", True) + @torch._inductor.config.patch("trace.provenance_tracking", True) def test_graph_transform_observer_node_tracking(self): class M(torch.nn.Module): def forward(self, x): @@ -156,7 +156,7 @@ def forward(self, x): [NodeSourceAction.REPLACE, NodeSourceAction.CREATE], ) - @torch._inductor.config.patch("trace.enabled", True) + @torch._inductor.config.patch("trace.provenance_tracking", True) def test_graph_transform_observer_deepcopy(self): class SimpleLinearModel(torch.nn.Module): def forward(self, x): diff --git a/test/fx/test_lazy_graph_module.py b/test/fx/test_lazy_graph_module.py index 6404b587d870..a17bcb9151de 100644 --- a/test/fx/test_lazy_graph_module.py +++ b/test/fx/test_lazy_graph_module.py @@ -69,7 +69,7 @@ def f(x): def test_needs_recompile(self): """ - Make sure needs_recompile() return the corrent state. + Make sure needs_recompile() return the correct state. """ def f(x): @@ -141,7 +141,7 @@ def f(x): self.assertTrue(isinstance(gm2, _LazyGraphModule)) self.assertTrue(gm2._needs_recompile()) - # make_fx will cal foward method of gm. That clears the _needs_recompile() + # make_fx will cal forward method of gm. That clears the _needs_recompile() # flag. self.assertFalse(gm._needs_recompile()) @@ -175,7 +175,7 @@ def f(x): def test_save_lazy_foward(self): """ - Save the lazy forward method and call it repeatly. Make sure we + Save the lazy forward method and call it repeatedly. Make sure we don't recompile for each such call. """ diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index ebe40f471e62..ab50b59fb96b 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -36,9 +36,9 @@ class TestPartitionerOrder(TestCase): def test_partitioner_order(self): m = AddModule() traced_m = torch.fx.symbolic_trace(m) - partions = DummyPartitioner(traced_m).propose_partitions() - partion_nodes = [list(partition.nodes) for partition in partions] - node_order = [n.name for n in partion_nodes[0]] + partitions = DummyPartitioner(traced_m).propose_partitions() + partition_nodes = [list(partition.nodes) for partition in partitions] + node_order = [n.name for n in partition_nodes[0]] for _ in range(10): traced_m = torch.fx.symbolic_trace(m) new_partion = DummyPartitioner(traced_m).propose_partitions() diff --git a/test/fx/test_pass_infra.py b/test/fx/test_pass_infra.py index 195a4fad2ba3..47531e15040e 100644 --- a/test/fx/test_pass_infra.py +++ b/test/fx/test_pass_infra.py @@ -131,7 +131,7 @@ def check_bad_args(graph_module, i): def test_topological_sort(self): """ - Tests that passes are correctly ordered based on contraints. + Tests that passes are correctly ordered based on constraints. """ def pass0(x): diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py index 55dbad003db5..7796a9e4a168 100644 --- a/test/higher_order_ops/test_invoke_quant.py +++ b/test/higher_order_ops/test_invoke_quant.py @@ -186,7 +186,7 @@ def quant_matching(match: Match, *args, **kwargs): @skipIfXpu( msg="MM Triton template fusion for XPU not work because the fusion" - " can not speedup, unskip untill #146568 fixed." + " can not speedup, unskip until #146568 fixed." ) @requires_gpu() @config.patch(prologue_fusion=True) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 052baebce337..c800eb78f905 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -21,6 +21,7 @@ normalize_gm, ) from torch._higher_order_ops.schema import find_hop_schema +from torch._inductor import config as inductor_config from torch._inductor.pattern_matcher import ( CallFunctionVarArgs, PatternMatcherPass, @@ -619,6 +620,7 @@ def fn(x, y): self.assertEqual(ref, res) res.sum().backward() + @inductor_config.patch("fx_graph_cache", False) def test_dropout_checks_joint_graph(self): # `dropout` tests that joint graph passes (not just partitioner) is ran # on the hop graphs. Inductor rng functionalization happens in the joint @@ -675,9 +677,9 @@ def forward(self, primals_0: "f32[8]"): sin: "f32[8]" = torch.ops.aten.sin.default(primals_0) inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None - gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); sin = None mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None @@ -690,6 +692,7 @@ def forward(self, primals_0: "f32[8]"): """, ) + @inductor_config.patch("fx_graph_cache", False) def test_dropout_checks_joint_graph_inference(self): # Checks that joint graph results in inductor seeds for just the inference graph @nested_compile_region @@ -719,9 +722,9 @@ def forward(self, arg0_1: "f32[8]"): class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[8]"): inductor_seeds_default: "i64[1]" = torch.ops.prims.inductor_seeds.default(1, device(type='cpu')) + inductor_lookup_seed_default: "i64[]" = torch.ops.prims.inductor_lookup_seed.default(inductor_seeds_default, 0); inductor_seeds_default = None inductor_random_default: "f32[8]" = torch.ops.prims.inductor_random.default([8], inductor_lookup_seed_default, 'rand'); inductor_lookup_seed_default = None - gt: "b8[8]" = torch.ops.aten.gt.Scalar(inductor_random_default, 0.5); inductor_random_default = None sin: "f32[8]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None mul: "f32[8]" = torch.ops.aten.mul.Tensor(gt, sin); gt = sin = None @@ -917,6 +920,7 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): """, ) + @inductor_config.patch("fx_graph_cache", False) def test_view_to_reshape(self): @nested_compile_region def gn(x): @@ -1079,7 +1083,7 @@ def fn(x, y): fake_prop_count = 0 - def _mock_invoke_subgraph(mode, subgraph, identifer, *operands): + def _mock_invoke_subgraph(mode, subgraph, identifier, *operands): nonlocal fake_prop_count fake_prop_count += 1 return (operands[0].clone(),) @@ -2077,7 +2081,7 @@ def fn(x, y): # NOTE THAT THIS TEST DOES NOT REALLY WORK # We wanted one invoke_subgraph called twice, but because of - # constant_args_idx changing in the grpah, the graph equivalence fails + # constant_args_idx changing in the graph, the graph equivalence fails if not TEST_WITH_CROSSREF: self.assertExpectedInline( diff --git a/test/inductor/custom_inductor_config.py b/test/inductor/custom_inductor_config.py new file mode 100644 index 000000000000..e29430728f94 --- /dev/null +++ b/test/inductor/custom_inductor_config.py @@ -0,0 +1,15 @@ +# Owner(s): ["module: inductor"] + +# This module is used in test_codecache.py to verify the correctness +# of FXGraphHashDetails when a custom inductor backend registers its own +# config object + +import sys + +from torch.utils._config_module import install_config_module + + +enable_optimisation: bool = False + +# adds patch, save_config, etc +install_config_module(sys.modules[__name__]) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 7950c3672cf4..a96a11da93bd 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -35,7 +35,11 @@ from torch.export.pt2_archive._package import load_pt2 from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM80OrLater +from torch.testing._internal.common_cuda import ( + _get_torch_cuda_version, + PLATFORM_SUPPORTS_FP8, + SM80OrLater, +) from torch.testing._internal.common_device_type import ( _has_sufficient_memory, skipCUDAIf, @@ -51,9 +55,12 @@ IS_FBCODE, IS_MACOS, IS_WINDOWS, + MACOS_VERSION, parametrize, + skipIfMPS, skipIfRocm, skipIfXpu, + TEST_MPS, TEST_WITH_ROCM, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut @@ -169,7 +176,9 @@ def forward(self, x, y): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) - if self.device == GPU_TYPE: + if self.device == "mps": + FileCheck().check("getKernelFunction(").run(code) + elif self.device == GPU_TYPE: FileCheck().check("launchKernel(").run(code) if config.aot_inductor.embed_kernel_binary: # Not expect to see launchKernel("CUBIN_FILE_NAME" @@ -184,10 +193,14 @@ def forward(self, x, y): IS_FBCODE, "toolchain doesn't support ptx to fatbin", ) + @skipIfMPS @skipIfRocm # Skip embed_kernel_binary == True for now as it shows random # failure on CI @common_utils.parametrize("embed_kernel_binary", [False]) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU_TYPE") @@ -425,6 +438,10 @@ def forward(self, y): ep, inductor_configs={"aot_inductor.use_runtime_constant_folding": True} ) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "Compilation error", + ) def test_aot_inductor_consts_cpp_build(self): class Model(torch.nn.Module): def __init__(self, device) -> None: @@ -781,6 +798,10 @@ def forward(self, a, b): inp = (torch.ones(3, device=self.device), torch.ones(3, device=self.device)) self.check_model(M(), inp) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "MPS BFloat16 is only supported on MacOS 14+", + ) def test_empty_cat_dtype_promotion(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -1434,6 +1455,22 @@ def forward(self, x): with config.patch({"aot_inductor.use_runtime_constant_folding": True}): self.check_model(Model(self.device), example_inputs) + @skipIfNoFBGEMM + def test_quantized_linear_bias_none(self): + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn(10, 10, device=device) + + def forward(self, x): + return torch.ops.quantized.linear_dynamic_fp16_unpacked_weight( + x, self.weight, None + ) + + example_inputs = (torch.randn(10, 10, device=self.device),) + with config.patch({"aot_inductor.use_runtime_constant_folding": True}): + self.check_model(Model(self.device), example_inputs) + @skipIfNoFBGEMM def test_quanatized_int8_linear(self): class Model(torch.nn.Module): @@ -1504,6 +1541,10 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "bfloat16 is only supported on MacOS 14+", + ) def test_size_with_unbacked_add_expr(self): # Tests AOTI autotuning to make sure the correct input tensor sizes # are generated for sizes that include an expr such as s0 + u0. @@ -1759,7 +1800,7 @@ def forward(self, x): Foo(user_float_feature_idx, self.device), example_inputs, strict=False ).run_decompositions() gm = ep.module() - self.check_model(gm, example_inputs) + self.check_model(gm.to(self.device), example_inputs) def test_large_grid(self): if self.device != GPU_TYPE: @@ -2427,6 +2468,7 @@ def forward(self, x): self.check_model(converted_model, example_inputs) + @skipIfMPS def test_fallback_mem_leak_fix(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2471,6 +2513,7 @@ def forward(self, x, y, idx): torch.testing.assert_close(actual, expected) @requires_multigpu() + @skipIfMPS def test_replicate_on_devices(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2510,6 +2553,7 @@ def forward(self, x, y): self.assertTrue(same(result_cpu, result_gpu.cpu())) @requires_multigpu() + @skipIfMPS def test_on_gpu_device1(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -2659,7 +2703,11 @@ def forward(self, x, y): model, example_inputs, atol=1e-4, rtol=1e-4 ) # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py - if self.device == GPU_TYPE: + if self.device == "mps": + self.code_check_count( + model, example_inputs, '.getKernelFunction("generated_kernel")', 1 + ) + elif self.device == GPU_TYPE: self.code_check_count( model, example_inputs, "triton_poi_fused_sin_0 = loadKernel(", 1 ) @@ -3058,10 +3106,9 @@ def forward(self, x): # Call eval() here so that batch_norm won't update the running stats # Use float64 to avoid numeric difference failure - model = Model().to(device=self.device, dtype=torch.float64).eval() - example_inputs = ( - torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64), - ) + dtype = torch.float32 if self.device == "mps" else torch.float64 + model = Model().to(device=self.device, dtype=dtype).eval() + example_inputs = (torch.randn(4, 3, 64, 64, device=self.device, dtype=dtype),) self.check_model(model, example_inputs) def test_triton_next_power_of_2(self): @@ -3122,6 +3169,7 @@ def forward(self, a, b, ranks): torch._dynamo.mark_dynamic(example_inputs[1], 0) self.check_model(Model(), example_inputs) + @skipIfMPS @common_utils.parametrize("grid_type", [1, 2, 3]) @common_utils.parametrize("num_dims", [1, 2]) @common_utils.parametrize("dynamic", [False, True]) @@ -4153,6 +4201,7 @@ def forward(self, x, y): expected = Model()(*example_inputs) torch.testing.assert_close(actual, expected) + @skipIfMPS @torch._dynamo.config.patch(capture_scalar_outputs=True) @common_utils.parametrize("dynamic", [False, True]) @common_utils.parametrize("autotuning", [False, True]) @@ -4329,24 +4378,13 @@ def forward(self, x, i1, i2, y): @patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"}) def test_runtime_checks(self): class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - if SM80OrLater: - - def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): - return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) + def forward(self, inputs): + return list(inputs.values()) - else: - - def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): - return (x0, x1, x2, x4, x5, x6, x7, x8, x9) - - inputs = [] + inputs = {} dtypes = [ torch.float16, torch.float32, - torch.float64, torch.bool, torch.int8, torch.int16, @@ -4354,60 +4392,75 @@ def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): torch.int64, torch.uint8, ] + + if not TEST_MPS: + dtypes.append(torch.float64) if SM80OrLater: dtypes.append(torch.bfloat16) + for dtype in dtypes: - inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device)) + inputs[f"x_{str(dtype)}"] = torch.ones( + 4, 8, 10, dtype=dtype, device=self.device + ) dim0 = Dim("s0", min=2, max=1024) dim1 = Dim("s1", min=2, max=512) dim2 = Dim("s2", min=2, max=128) dynamic_shapes = { - "x0": {0: dim0}, - "x1": {0: dim0}, - "x2": {0: dim0}, - "x4": {1: dim1}, - "x5": {1: dim1}, - "x6": {}, - "x7": {2: dim2}, - "x8": {2: dim2}, - "x9": {2: dim2}, + "x_torch.float16": {0: dim0}, + "x_torch.float32": {0: dim0}, + "x_torch.bool": {1: dim1}, + "x_torch.int8": {1: dim1}, + "x_torch.int16": {}, + "x_torch.int32": {2: dim2}, + "x_torch.int64": {2: dim2}, + "x_torch.uint8": {2: dim2}, } + if not TEST_MPS: + dynamic_shapes["x_torch.float64"] = {0: dim0} if SM80OrLater: - dynamic_shapes["x3"] = {1: dim1} + dynamic_shapes["x_torch.bfloat16"] = {1: dim1} m = Model() - inputs = tuple(inputs) + inputs = (inputs,) + dynamic_shapes = (dynamic_shapes,) with torch.no_grad(): so_path = AOTIRunnerUtil.legacy_compile( m, inputs, dynamic_shapes=dynamic_shapes ) + + # Expected results for the following checks: + # ("unmatched dtype", "unmatched dim value at", "dim value is too", "unmatched stride value at") + if SM80OrLater: + # 10 dynamic dims + expected_results = (10, 21, 18, 21) + elif TEST_MPS: + # 8 dynamic dims + expected_results = (8, 17, 14, 16) + else: + # 9 dynamic dims + expected_results = (9, 19, 16, 19) + with open(os.path.splitext(so_path)[0] + ".cpp") as cpp: src_code = cpp.read() FileCheck().check_count( "unmatched dtype", - 10 if SM80OrLater else 9, + expected_results[0], exactly=True, ).run(src_code) FileCheck().check_count( "unmatched dim value at", - 21 - if SM80OrLater - else 19, # we have 9 dynamic dims for which we generate different checks + expected_results[1], exactly=True, ).run(src_code) FileCheck().check_count( "dim value is too", - 18 - if SM80OrLater - else 16, # we have 9 dynamic dims for which we generate two checks + expected_results[2], exactly=True, ).run(src_code) FileCheck().check_count( "unmatched stride value at", - 21 - if SM80OrLater - else 19, # we have 9 symbolic strides for which we don't generate checks + expected_results[3], exactly=True, ).run(src_code) @@ -4671,6 +4724,10 @@ def forward(self, w, i, o): ) self.check_model(Model(), example_inputs) + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "FFT operations are only supported on MacOS 14+", + ) def test_fft_c2c(self): class Model(torch.nn.Module): def forward(self, x): @@ -4837,16 +4894,15 @@ def forward(self, a): a = torch.randn(batch, M, K, device=self.device) example_inputs = (a,) - kernel_calls = ( - [ + if self.device == "mps": + kernel_calls = [("aoti_torch_mps_addmm_out", 2)] + elif self.device == GPU_TYPE: + kernel_calls = [ ("triton_poi_fused_0", 1), (f"aoti_torch_{GPU_TYPE}_addmm_out", 2), ] - if self.device == GPU_TYPE - else [ - ("aoti_torch_cpu_addmm_out", 2), - ] - ) + else: + kernel_calls = [("aoti_torch_cpu_addmm_out", 2)] # test default debug printing all tensor values codegen with config.patch({"aot_inductor.debug_intermediate_value_printer": "2"}): @@ -5637,6 +5693,53 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) + def test_update_constant_buffer_simple(self): + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.weight = torch.randn((3, 3), device=device) + + def forward(self, a): + return a + self.weight + + model = Model(self.device) + a = torch.randn((3, 3), device=self.device) + example_inputs = (a,) + + with torch.no_grad(), config.patch({"always_keep_tensor_constants": True}): + so_path = AOTIRunnerUtil.legacy_compile( + model=model, + example_inputs=example_inputs, + ) + + runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path) + + # Let's check whether the model has correct constant name mapping. + expected_original_fqns = { + "L__self___weight": "L__self___weight", + } + self.assertEqual( + expected_original_fqns, runner.get_constant_names_to_original_fqns() + ) + + test_inputs = torch.randn((3, 3), device=self.device) + new_weight = torch.randn((3, 3), device=self.device) + model.weight = new_weight + attach_weights = {"L__self___weight": new_weight} + runner.update_constant_buffer(attach_weights, False, False, False) + expected = model(test_inputs) + + def runner_call(*args, **kwargs): + call_spec = runner.get_call_spec() # type: ignore[attr-defined] + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, kwargs))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined] + return pytree.tree_unflatten(flat_outputs, out_spec) + + output = runner_call(test_inputs) + self.assertEqual(expected, output) + def test_update_inactive_constant_buffer(self): class Model(torch.nn.Module): def __init__(self, n, k, device): @@ -5869,6 +5972,31 @@ def runner_call(*args, **kwargs): ) self.assertEqual(new_expected, new_output) + new_weights = { + "L__self___weight": torch.randn(N, K, device=self.device), + "L__self___bias": torch.randn(N, device=self.device), + } + + runner.update_constant_buffer(new_weights, True, False, True) + runner.swap_constant_buffer() + + model.weight = torch.nn.Parameter(new_weights["L__self___weight"]) + model.bias = torch.nn.Parameter(new_weights["L__self___bias"]) + + updated_state_dict = { + "weight": torch.ones_like(model.weight), + "bias": torch.zeros_like(model.bias), + } + + model.load_state_dict(updated_state_dict) + + new_output = runner_call(test_inputs) + expected_output = model(test_inputs) + torch.testing.assert_close(new_output, expected_output) + + with self.assertRaises(AssertionError): + torch.testing.assert_close(new_expected, new_output) + def test_cond_share_predicte(self): class Model(torch.nn.Module): def forward(self, predicate, x): @@ -5982,13 +6110,17 @@ def forward(self, x, y): ) @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") + @unittest.skipIf( + TEST_MPS and MACOS_VERSION < 14.0, + "FFT operations are only supported on MacOS 14+", + ) def test_stft(self): N_FFT = 400 HOP_LENGTH = 160 class Model(torch.nn.Module): def forward(self, x): - window = torch.hann_window(N_FFT).to(x.device) + window = torch.hann_window(N_FFT, device=x.device) stft = torch.stft( x, N_FFT, HOP_LENGTH, window=window, return_complex=True ) @@ -6072,6 +6204,7 @@ def forward(self, x, y): ) self.check_model(Model(), example_inputs) + @skipIfMPS @skipIfXpu( msg="aten::convert_weight_to_int4pack is not currently implemented for XPU" ) @@ -6611,6 +6744,13 @@ def fail_cpu(is_skip=False): ) +def fail_mps(is_skip=False): + return TestFailure( + ("mps",), + is_skip=is_skip, + ) + + def fail_gpu(suffixes: tuple[str, ...], is_skip=False): return TestFailure( suffixes, @@ -6629,12 +6769,99 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_gpu(("cuda", "xpu")), "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")), + "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")), # No scaled_dot_product_efficient_attention implementation for XPU yet. "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)), # No fft implementation for XPU yet. "test_fft_c2c": fail_gpu(("xpu",), is_skip=True), } +MPS_TEST_FAILURES = { + # aten::_embedding_bag is not currently implemented for the MPS device. + "test_embedding_bag": fail_mps(), + # aten::_embedding_bag is not currently implemented for the MPS device. + "test_misc_1_max_autotune_False": fail_mps(), + "test_misc_1_max_autotune_True": fail_mps(), + # aten::_scaled_dot_product_efficient_attention is not currently implemented for the MPS device. + "test_scaled_dot_product_efficient_attention": fail_mps(), + # aten::_int_mm is not implemented for MPS backend + "test__int_mm": fail_mps(), + # MPS doesn't support float64 + "test_while_loop_with_conv_dynamic_True": fail_mps(), + "test_while_loop_with_conv_dynamic_False": fail_mps(), + # MPS doesn't support float8 + "test_fp8": fail_mps(), + "test_fp8_view_of_param": fail_mps(), + # unsupported operator: aten._scaled_dot_product_attention_math_for_mps.default + "test_issue_140766": fail_mps(), + # cannot initialize a parameter of type 'double' with an rvalue of type 'std::nullptr_t' + "test_fallback_kernel_with_symexpr_output": fail_mps(), + # while-loop subgraph calls same kernel as outside. need to figure out how to + # either (1) tell outside to initialize a new kernel or (2) generate + # subgraph as a separate function, which would(?) cause (1) to happen automatically. + "test_while_loop_nested": fail_mps(), + # correctness issue + "test_index_put_with_none_index": fail_mps(), + # Dynamism + "test_shifted_constraint_ranges": fail_mps(), + "test_while_loop_with_sym_expr_cond_dynamic_True": fail_mps(), + "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_mps(), + "test_cond_mismatched_branch_output_dynamic_True": fail_mps(), + "test_cond_unbacked_symint_closure_dynamic_True": fail_mps(), + "test_cond_non_tensor_predicates_dynamic_True": fail_mps(), + "test_zero_grid_with_unbacked_symbols": fail_mps(), + "test_reuse_kernel_dynamic": fail_mps(is_skip=True), + "test_cond_with_parameters": fail_mps(is_skip=True), + "test_cond_share_predicte": fail_mps(is_skip=True), + # Error device may not be nil + "test_zero_size_weight": fail_mps(is_skip=True), + # RuntimeError: Cannot compare two tensors on different devices. Got: cpu and mps:0 + "test_aoti_constant_tensor_name_collision": fail_mps(is_skip=True), + # MPS doesn't support triton + "test_autotuning_args_reuse": fail_mps(), + "test_triton_autotuning": fail_mps(), + "test_triton_dynamic_launcher_grid": fail_mps(), + "test_triton_dynamic_launcher_grid_infer_from_tensor": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_False_tma_version_new": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_False_tma_version_old": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_True_tma_version_new": fail_mps(), + "test_triton_kernel_on_device_tma_dynamic_True_tma_version_old": fail_mps(), + "test_size_with_unbacked_add_expr_transitive": fail_mps(), + "test_size_with_unbacked_add_and_mul_expr": fail_mps(), + "test_triton_next_power_of_2": fail_mps(), + "test_sympy_cpp_printer_min_max_minmax0": fail_mps(), + "test_sympy_cpp_printer_min_max_minmax1": fail_mps(), + "test_triton_kernel_dynamic_shape_with_div": fail_mps(), + "test_triton_kernel_reinterpret_view": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_False_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_1d_dynamic_True_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_False_tma_version_old_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_new_mps": fail_mps(), + "test_triton_kernel_tma_descriptor_2d_dynamic_True_tma_version_old_mps": fail_mps(), + "test_triton_kernel_sympy_expr_arg": fail_mps(), + "test_triton_kernel_sympy_fn_like_arg": fail_mps(), + "test_triton_kernel_with_none_input": fail_mps(), + "test_triton_kernel_equal_to_1_arg": fail_mps(), + "test_triton_kernel_with_none_inputs_and_equal_to_1_arg": fail_mps(), + "test_triton_kernel_equal_to_1_float_arg_dynamic_True": fail_mps(), + "test_triton_kernel_equal_to_1_float_arg_dynamic_False": fail_mps(), + "test_triton_kernel_weird_param_order": fail_mps(), + "test_triton_kernel_dynamic_grid": fail_mps(), + "test_repeated_user_defined_triton_kernel_embed_kernel_binary_False": fail_mps(), + "test_repeated_user_defined_triton_kernel_embed_kernel_binary_True": fail_mps(), + "test_triton_kernel_extern_kernel_arg": fail_mps(), + "test_triton_kernel_multi_output_arg": fail_mps(), + "test_triton_kernel_reinterpret_view_mem_leak": fail_mps(), + "test_triton_mutated_autotuning": fail_mps(), + "test_sym_i64_input_codegen": fail_mps(), + "test_none_args_aot_codegen": fail_mps(), + "test_aoti_debug_printer_sym_inputs": fail_mps(), + "test_aoti_debug_printer_user_defined_triton_kernel": fail_mps(), +} + class AOTInductorTestABICompatibleCpu(TestCase): device = "cpu" @@ -6672,6 +6899,26 @@ class AOTInductorTestABICompatibleGpu(TestCase): GPU_TEST_FAILURES, ) + +@unittest.skipIf(not torch.backends.mps.is_available(), "No MPS backend available") +class AOTInductorTestABICompatibleMps(TestCase): + device = "mps" + device_type = "mps" + check_model = check_model + check_model_with_multiple_inputs = check_model_with_multiple_inputs + code_check_count = code_check_count + allow_stack_allocation = False + use_minimal_arrayref_interface = False + + +copy_tests( + AOTInductorTestsTemplate, + AOTInductorTestABICompatibleMps, + "mps", + MPS_TEST_FAILURES, +) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index 31de9ac4c71d..aa3c589b4546 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -416,6 +416,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): @skipIfXpu @skipIfRocm + @unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops") def test_custom_op_square(self) -> None: class Model(torch.nn.Module): def forward(self, x): @@ -511,6 +512,7 @@ def fail_cuda(is_skip=False): # quantized unsupported for GPU "test_quantized_linear": fail_cuda(), "test_quanatized_int8_linear": fail_cuda(), + "test_quantized_linear_bias_none": fail_cuda(), } diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2f2b92168c6e..51343b6b1883 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -16,11 +16,17 @@ import torch from torch._inductor.codecache import get_kernel_bin_format -from torch._inductor.package import AOTICompiledModel, load_package, package_aoti +from torch._inductor.package import load_package, package_aoti from torch._inductor.test_case import TestCase from torch._inductor.utils import fresh_cache from torch.export import Dim -from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents +from torch.export.experimental import _ExportPackage +from torch.export.pt2_archive._package import ( + AOTICompiledModel, + load_pt2, + load_weights_to_pt2_contents, +) +from torch.testing._internal.common_cuda import _get_torch_cuda_version from torch.testing._internal.common_utils import ( IS_FBCODE, skipIfRocm, @@ -30,20 +36,6 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU -try: - from test_static_linkage_utils import ( - get_static_linkage_main_cpp_file, - get_static_linkage_makelist_file_cpu, - get_static_linkage_makelist_file_cuda, - ) -except ImportError: - from .test_static_linkage_utils import ( - get_static_linkage_main_cpp_file, - get_static_linkage_makelist_file_cpu, - get_static_linkage_makelist_file_cuda, - ) - - def skipif(predicate: Callable[[str, bool], bool], reason: str): def decorator(func): @functools.wraps(func) @@ -152,6 +144,28 @@ def check_package_cpp_only(self: TestCase) -> None: if shutil.which("make") is None: raise unittest.SkipTest("make is not available") + def cmake_compile_and_run(self, base_dir): + custom_env = os.environ.copy() + custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) + build_path = Path(base_dir) / "build" + build_path.mkdir() + subprocess.run( + ["cmake", ".."], + cwd=build_path, + env=custom_env, + check=True, + ) + subprocess.run(["make"], cwd=build_path, check=True) + result = subprocess.run( + ["./build/main"], + cwd=base_dir, + check=True, + capture_output=True, + text=True, + ) + + return result + def cmake_compile(self, model, example_inputs, options, tmp_dir): """ Exports model, compiles it using AOTInductor, extracts the @@ -249,6 +263,9 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @skipIfXpu # build system may be different def test_compile_after_package(self): self.check_package_cpp_only() @@ -294,6 +311,9 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary @@ -338,6 +358,9 @@ def forward(self, x, y): actual = optimized(*example_inputs) self.assertTrue(torch.allclose(actual, expected)) + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfXpu # build system may be different def test_compile_after_package_static(self): @@ -396,10 +419,13 @@ def forward(self, x, y): with self.assertRaisesRegex(Exception, "Invalid AOTI model name"): self.cmake_compile(model, example_inputs, options, "") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary - def test_run_static_linkage_model(self): + def test_compile_with_exporter(self): self.check_package_cpp_only() class Model1(torch.nn.Module): @@ -410,64 +436,45 @@ class Model2(torch.nn.Module): def forward(self, x, y): return x - y + def default(*args, **kwargs): + return None + example_inputs = ( - torch.randn(10, 10, device=self.device), - torch.randn(10, 10, device=self.device), + torch.ones(3, 3).to(self.device), + torch.ones(3, 3).to(self.device), ) - model1 = Model1().to(self.device) - model2 = Model2().to(self.device) - - models = [model1, model2] + package = _ExportPackage() + m1 = Model1() + m2 = Model2() + exporter1 = package._exporter("Plus", m1)._define_overload("default", default) + exporter2 = package._exporter("Minus", m2)._define_overload("default", default) + exporter1(*example_inputs) + exporter2(*example_inputs) - i = 0 - model_names = ["Plus", "Minus"] - with ( - tempfile.TemporaryDirectory() as tmp_dir, - ): - for i in range(2): - model = models[i] - # TODO: should be done through _ExportPackage - ep = torch.export.export(model, example_inputs) - - package_path = torch._inductor.aoti_compile_and_package( - ep, - inductor_configs={ - "aot_inductor.compile_standalone": True, - "always_keep_tensor_constants": True, - "aot_inductor.model_name_for_generated_files": model_names[i], - }, + for package_example_inputs in [True, False]: + with ( + tempfile.TemporaryDirectory() as tmp_dir, + ): + package._compiled_and_package( + tmp_dir + "/package.pt2", True, package_example_inputs ) - with ( - zipfile.ZipFile(package_path, "r") as zip_ref, - ): - zip_ref.extractall(tmp_dir) - - file_str = get_static_linkage_main_cpp_file() - with open(Path(tmp_dir) / "main.cpp", "w") as f: - f.write(file_str) - - if self.device == GPU_TYPE: - cmake_file_str = get_static_linkage_makelist_file_cuda() - else: - cmake_file_str = get_static_linkage_makelist_file_cpu() - with open(Path(tmp_dir) / "CMakeLists.txt", "w") as f: - f.write(cmake_file_str) - - build_path = Path(tmp_dir) / "build" - build_path.mkdir() - custom_env = os.environ.copy() - custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent) - subprocess.run( - ["cmake", ".."], - cwd=build_path, - env=custom_env, - ) - subprocess.run(["make"], cwd=build_path, check=True) - subprocess.run( - ["./main", f"{tmp_dir}/", self.device], cwd=build_path, check=True - ) + # Test compiling generated files + result = self.cmake_compile_and_run(tmp_dir) + if package_example_inputs: + if self.device == GPU_TYPE: + self.assertEqual( + result.stdout, + "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CUDAFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + " 0 0 0\n 0 0 0\n[ CUDAFloatType{3,3} ]\n", + ) + else: + self.assertEqual( + result.stdout, + "output_tensor1 2 2 2\n 2 2 2\n 2 2 2\n[ CPUFloatType{3,3} ]\noutput_tensor2 0 0 0\n" + " 0 0 0\n 0 0 0\n[ CPUFloatType{3,3} ]\n", + ) def test_metadata(self): class Model(torch.nn.Module): diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 9d25aa475601..a2706933d615 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -102,6 +102,8 @@ def legacy_load_runner(device, so_path: str) -> "AOTIModelContainerRunner": return torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) elif device == "xpu": return torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) + elif device == "mps": + return torch._C._aoti.AOTIModelContainerRunnerMps(so_path, 1) else: return torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index da146acd6368..65df4912a41c 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -185,9 +185,15 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes + # Custom comment for test foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None return ()""", # noqa: B950 + ignore_comments=True, + ) + + # stack trace should be in post_grad_graph + self.assertTrue( + "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -328,10 +334,16 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): - # No stacktrace found for following nodes + # Custom comment for test foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1); \ arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None return ()""", + ignore_comments=True, + ) + + # stack trace should be in post_grad_graph + self.assertTrue( + "code: torch.ops.mylib.foo(x, y, z, 2, n)" in post_grad_graphs, ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -433,7 +445,7 @@ def run_aot_eager(self, f, orig_args, _dynamic=False): aot_eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) result = None diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 68626aafcc0d..93545ed93cc3 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -66,6 +66,12 @@ from torch.testing._internal.triton_utils import requires_cuda +try: + from . import custom_inductor_config +except ImportError: + import custom_inductor_config + + if HAS_TRITON: import triton # @manual @@ -1854,6 +1860,7 @@ def f(x): @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"autograd_cache_normalize_inputs": True}) def test_split_module(self): class Mod(torch.nn.Module): def forward(self, x, a0, a1, b0, b1, c0, c1): @@ -1900,6 +1907,14 @@ def t(): y = ca0(a0, x, a1) y = ca1(b0, y, b1) y = ca2(c0, y, c1) + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) + # TODO: split_module causes ca1 and ca2 to have different type annotations + # for the parameter x, so we can only AOTAutogradCache cache hit once instead of twice + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2) expected = Mod()(*example_inputs) self.assertEqual(y, expected) @@ -2454,6 +2469,50 @@ def uuid(self) -> Optional[Union[bytes, str]]: pickler.dumps(details3), ) + def test_hash_custom_backend_config(self): + """ + Test cache correctness when a custom inductor codegen config + is installed + """ + with patch_inductor_backend( + "cpu", custom_backend_config=custom_inductor_config + ): + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + pickler = FxGraphCachePickler(gm) + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) + self.assertEqual(pickler.dumps(details1), pickler.dumps(details2)) + + custom_inductor_config.enable_optimisation = True + details3 = FxGraphHashDetails(None, [], {}, []) + self.assertNotEqual(pickler.dumps(details2), pickler.dumps(details3)) + + torch._dynamo.reset() + counters.clear() + + custom_inductor_config.enable_optimisation = False + x = torch.zeros(32) + y = torch.zeros(32) + compiled_fn = torch.compile(torch.add) + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + torch._dynamo.reset() + counters.clear() + + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + torch._dynamo.reset() + counters.clear() + + # Changing the custom config should trigger a recompilation + custom_inductor_config.enable_optimisation = True + compiled_fn(x, y) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + def test_bypass_unsupported(self): """ Test _reduce_unsupported diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 6eba88ecae97..04297c38bf29 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -18,7 +18,7 @@ from torch._inductor.compile_fx import _InProcessFxCompile, FxCompile, FxCompileMode from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import TEST_WITH_ASAN +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, TEST_WITH_ASAN from torch.testing._internal.inductor_utils import ( GPU_TYPE, IS_BIG_GPU, @@ -29,6 +29,16 @@ ) +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : Debug and confirm pass_fds status on Windows. + sys.stderr.write( + "Almost UTs failed: pass_fds not supported on Windows, skip them on Windows.\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("pass_fds not supported on Windows") + + # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 7f41613646d4..bb59b626bef1 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -26,7 +26,6 @@ run_fw_bw_and_get_code, ) from torch.fx.experimental.proxy_tensor import make_fx -from torch.nn.attention import sdpa_kernel, SDPBackend from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -178,10 +177,9 @@ def test_effn_attn_bias_padding_misaligned(self): inputs = [q, k, v, mask] def f(q, k, v, mask): - with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): - return F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0 - ) + return F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0 + ) f_compiled = torch.compile(f) @@ -1845,6 +1843,7 @@ def fn(x): self.assertEqual(graph.disable_cudagraphs_reason, None) self.assertEqual(graph.device_types, {"cuda"}) + @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_triton_interpret(self): import subprocess @@ -2098,6 +2097,7 @@ def get_input() -> torch.Tensor: self.assertIn("znumel", code) @xfailIfPy312Plus # https://github.com/pytorch/pytorch/issues/142032 + @unittest.skipIf(config.is_fbcode(), "Dependence on functorch.einops") def test_repeated_masked_load(self): target_size = (8, 2) mem_eff_temporal_upsampling_interp_chunks = 2 diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 970fe64a758d..36f73b200476 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import ctypes +import unittest import torch from torch._inductor.async_compile import AsyncCompile @@ -9,6 +10,10 @@ from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import fresh_cache +from torch.testing._internal.inductor_utils import HAS_CUDA + + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") _SOURCE_CODE = r""" @@ -36,6 +41,7 @@ class TestCUDACodeCache(InductorTestCase): + @requires_cuda def test_cuda_load(self): with fresh_cache(): # Test both .o and .so compilation. @@ -63,12 +69,14 @@ def test_cuda_load(self): ) torch.testing.assert_close(y, expected_y) + @requires_cuda def test_compilation_error(self): with fresh_cache(): error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) with self.assertRaises(CUDACompileError): CUDACodeCache.compile(error_source_code, "o") + @requires_cuda def test_async_compile(self): with fresh_cache(): async_compile = AsyncCompile() diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index c80bd13af361..3b230865bcd9 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -421,7 +421,9 @@ def forward(self, a, b, c): 2, 4, ], # guarantees > 1 choices - "force_disable_caches": True, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "autotune_local_cache": False, } ): from torch._inductor.utils import run_and_get_code @@ -1530,7 +1532,8 @@ def mm(a, b): "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS", "cuda.cutlass_max_profiling_configs": 2, # needed for log searching - "force_disable_caches": True, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, } ): with ( @@ -1552,6 +1555,178 @@ def mm(a, b): num_ops = int(match.group(1)) self.assertTrue(num_ops > 0, "The number of ops should be greater than 0") + @unittest.skipIf(not SM90OrLater, "need sm_90") + def test_maybe_append_choice_caching(self): + """ + Test if maybe_append_choice's caching leads to correct results and + shorter maybe_append_choice time. + """ + + NUM_ITERATIONS = 10 + + class TestModule(torch.nn.Module): + def forward(self, A, B): + for _ in range(NUM_ITERATIONS): + A = A @ B / 32 + return A + + model = TestModule().cuda() + A = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") + B = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda").t() + + expected = model(A, B) + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + compiled_model = torch.compile(model, fullgraph=True) + actual = compiled_model(A, B) + + torch.testing.assert_close(actual, expected) + + # Check render call count: render is called uniquely for each codegen + # and for each finalized codegen. + self.assertEqual( + render_call_count, NUM_ITERATIONS + DEFAULT_INST_LEVEL_MM_CONFIG + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_multiple_mm(self): + """ + Test multiple matrix multiplications with different shapes in a single nn.Module. + """ + + class MultipleMMModel(torch.nn.Module): + def forward(self, a, b, c, d): + # First mm with shape (128, 64) @ (64, 32) -> (128, 32) + mm1 = a @ b + # Second mm with shape (256, 128) @ (128, 64) -> (256, 64) + mm2 = c @ d + return mm1, mm2 + + model = MultipleMMModel().cuda() + + # Create tensors with different shapes + a = torch.randn(128, 64).cuda().half() + b = torch.randn(32, 64).cuda().half().t() + c = torch.randn(256, 128).cuda().half() + d = torch.randn(64, 128).cuda().half().t() + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + # Get expected results + expected = model(a, b, c, d) + + # Compile and run + compiled_model = torch.compile(model) + actual = compiled_model(a, b, c, d) + + # Verify results + torch.testing.assert_close(actual, expected) + + num_matmuls = 2 + self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + def test_multiple_mm_with_dynamic_shape(self): + """ + Test multiple matrix multiplications where one has dynamic shapes. + """ + + class MultipleMMDynamicModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.c = torch.randn(64, 256).cuda().half() + self.d = torch.randn(128, 256).cuda().half().t() + + def forward(self, a, b): + # dynamic shape matmul + mm1 = a @ b + # static shape matmul + mm2 = self.c @ self.d + return mm1, mm2 + + model = MultipleMMDynamicModel().cuda() + + # Create tensors with different shapes + a = torch.randn(128, 64).cuda().half() + b = torch.randn(32, 64).cuda().half().t() + + # Track render calls + from torch._inductor.codegen.cuda.gemm_template import CUTLASSGemmTemplate + + original_render = CUTLASSGemmTemplate.render + render_call_count = 0 + + def counting_render(self, *args, **kwargs): + nonlocal render_call_count + render_call_count += 1 + return original_render(self, *args, **kwargs) + + with mock.patch.object(CUTLASSGemmTemplate, "render", counting_render): + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTLASS", + "cuda.cutlass_max_profiling_configs": 2, + "fx_graph_cache": False, + "fx_graph_remote_cache": False, + "cuda.enable_caching_codegen": True, + } + ): + # Get expected results + expected = model(a, b) + + # Compile and run + compiled_model = torch.compile(model, dynamic=True) + actual = compiled_model(a, b) + + # Verify results + torch.testing.assert_close(actual, expected) + + num_matmuls = 2 + self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_matmul_same_tensor(self): @@ -1846,8 +2021,14 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): """ full_ops = _gen_ops_cached(arch, cuda_version) ops = pytree.tree_flatten(full_ops)[0] + + # sanity check self.assertGreater(len(ops), 1000, "Too few ops generated") + # test if configuration name is unique + op_config_names = [op.configuration_name() for op in ops] + self.assertEqual(len(op_config_names), len(set(op_config_names))) + serializer = get_cutlass_operation_serializer() self.assertIsNotNone(serializer) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index fa6400dd9c27..e14afcea81b0 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5,6 +5,7 @@ import random import string import unittest +import warnings from collections import namedtuple from contextlib import contextmanager from dataclasses import dataclass @@ -4235,6 +4236,52 @@ def test_large_batch_heads_grid_dimension(self, device): self.assertEqual(key.grad.shape, key.shape) self.assertEqual(value.grad.shape, value.shape) + @supported_platform + def test_debug_flag_disables_internal_compilation(self, device): + """Test that _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG flag bypasses internal compilation.""" + import torch.nn.attention.flex_attention as fa + + original_flag = fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG + original_warnings_shown = fa._WARNINGS_SHOWN.copy() + + try: + B, H, S, D = 1, 1, 128, 64 + query = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + key = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + value = torch.randn(B, H, S, D, device=device, dtype=torch.float32) + + def simple_score_mod(score, b, h, q_idx, kv_idx): + return score + + # Test with debug flag False - should warn + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False + fa._WARNINGS_SHOWN.clear() + + with self.assertWarns(UserWarning) as cm: + out_compiled = fa.flex_attention( + query, key, value, score_mod=simple_score_mod + ) + + self.assertIn( + "flex_attention called without torch.compile", str(cm.warning) + ) + + # Test with debug flag True - should NOT warn + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True + + # Should not error + with warnings.catch_warnings(): + warnings.simplefilter("error") + out_debug = fa.flex_attention( + query, key, value, score_mod=simple_score_mod + ) + + torch.testing.assert_close(out_compiled, out_debug, rtol=1e-4, atol=1e-4) + + finally: + fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = original_flag + fa._WARNINGS_SHOWN = original_warnings_shown + class TestBlockMask(InductorTestCase): def setUp(self): diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 4d52775ccbad..a0e1b47032b8 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -1023,7 +1023,7 @@ def dot_prod_attention( return attn_weights.matmul(value), key, value tensor_shape = (4, 2, 16, 32) - attn_mask = torch.randn((1, 1, 1, 2), dtype=torch.float, device=self.device) + attn_mask = torch.randn((1, 1, 2, 2), dtype=torch.float, device=self.device) args = [ torch.randn(tensor_shape, device=self.device), torch.randn(tensor_shape, device=self.device), @@ -1036,6 +1036,16 @@ def dot_prod_attention( has_dropout=False, check_train=False, ) + # test attn_mask with stride of last dim != 1 + attn_mask_ = attn_mask.transpose(2, 3) + args[3] = attn_mask_ + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + contains=self.device == "cpu", + ) def _test_sdpa_rewriter_23(self): def dot_prod_attention( @@ -1065,6 +1075,44 @@ def dot_prod_attention( check_train=False, ) + def _test_sdpa_rewriter_24(self): + def dot_prod_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor, + ) -> torch.Tensor: + """Input tensors assumed to have shape (batch_size, n_head, seq_len, embed_dim)""" + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + embed_dim = query.size(3) + q = query.view(bs * n_head, seq_len, embed_dim) + k = key.reshape(bs * n_head, seq_len, embed_dim) + v = value.reshape(bs * n_head, seq_len, embed_dim) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, seq_len) + attn_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, seq_len) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, embed_dim) + return attn_output + + tensor_shape = (4, 2, 16, 32) + attn_mask = torch.randn((1, 1, 16, 16), dtype=torch.float, device=self.device) + args = [ + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + torch.randn(tensor_shape, device=self.device, dtype=torch.float), + attn_mask, + ] + self._check_common( + dot_prod_attention, + args1=args, + has_dropout=False, + check_train=False, + ) + if HAS_XPU or (HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION): @@ -1133,6 +1181,9 @@ class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) + test_sdpa_rewriter_24_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 + ) class SDPAPatternRewriterGpuDynamicTests(SDPAPatternRewriterGpuTests): use_static_shapes = False @@ -1199,6 +1250,9 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_23_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_23 ) + test_sdpa_rewriter_24_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_24 + ) class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): use_static_shapes = False diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 516120da2986..090a7e8e29d3 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -286,24 +286,6 @@ def forward(self, x): return torch.stack((stack_input, stack_other), dim=0) -@requires_gpu() -@torch._inductor.config.patch( - pre_grad_fusion_options={ - "batch_linear": {}, - "batch_linear_lhs": {}, - "batch_layernorm": {}, - "batch_tanh": {}, - "batch_relu": {}, - "batch_sigmoid": {}, - }, - post_grad_fusion_options={ - "batch_aten_add": {}, - "batch_aten_mul": {}, - "batch_aten_sub": {}, - "batch_aten_div": {}, - "group_linear": {"require_fbgemm": True}, - }, -) class TestGroupBatchFusion(TestCase): def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): if len(set(ref_dict.keys())) != len(set(res_dict.keys())): @@ -332,7 +314,14 @@ def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) ) + @requires_gpu() @unittest.skipIf(not has_fbgemm, "requires fbgemm") + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "group_linear": {"require_fbgemm": True}, + }, + ) def test_group_linear_fusion(self): z = 10 for has_bias in [True, False]: @@ -355,13 +344,16 @@ def test_group_linear_fusion(self): counters["inductor"]["group_linear"], 4, ) - self.assertEqual( - counters["inductor"]["batch_aten_add"], - 0, - ) counters.clear() + @requires_gpu() @unittest.skipIf(not has_fbgemm, "requires fbgemm") + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "group_linear": {"require_fbgemm": True}, + }, + ) def test_group_linear_fusion_different_shapes(self): counters.clear() module = MyModule2().eval().to(GPU_TYPE) @@ -386,13 +378,14 @@ def test_group_linear_fusion_different_shapes(self): counters["inductor"]["group_linear"], 2, ) - self.assertEqual( - counters["inductor"]["batch_aten_mul"], - 1, - ) counters.clear() + @requires_gpu() @unittest.skipIf(GPU_TYPE == "mps", "welford_reduce is yet not implemented for MPS") + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_layernorm": {}}, + post_grad_fusion_options={}, + ) def test_batch_layer_norm_fusion(self): for has_weight in [True, False]: for has_bias in [True, False]: @@ -410,6 +403,11 @@ def test_batch_layer_norm_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_linear_lhs": {}}, + post_grad_fusion_options={}, + ) def test_batch_linear_lhs_fusion(self): z = 10 for has_bias in [True, False]: @@ -427,6 +425,11 @@ def test_batch_linear_lhs_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={"batch_linear": {}}, + post_grad_fusion_options={}, + ) def test_batch_linear_pre_grad_fusion(self): for has_bias in [True, False]: counters.clear() @@ -443,6 +446,19 @@ def test_batch_linear_pre_grad_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "batch_relu": {}, + "batch_sigmoid": {}, + }, + post_grad_fusion_options={ + "batch_aten_add": {}, + "batch_aten_mul": {}, + "batch_aten_sub": {}, + "batch_aten_div": {}, + }, + ) def test_pointwise_op_fusion(self): counters.clear() module = TestPoitwiseOps(GPU_TYPE) diff --git a/test/inductor/test_kernel_optimization.py b/test/inductor/test_kernel_optimization.py new file mode 100644 index 000000000000..aabc8e83a06d --- /dev/null +++ b/test/inductor/test_kernel_optimization.py @@ -0,0 +1,92 @@ +# Owner(s): ["module: inductor"] + +import torch +import torch._inductor +from torch._dynamo.utils import counters +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu + + +class TestEinsumtoPointwise(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, + input: torch.Tensor, + weights: torch.Tensor, + bias: torch.Tensor, + input2: torch.Tensor, + weights2: torch.Tensor, + bias2: torch.Tensor, + ) -> torch.Tensor: + output = torch.functional.einsum("bni, nio -> bno", input, weights) + add1 = output.add(bias) + output2 = torch.functional.einsum("bni, bnio -> bno", input2, weights2) + add2 = output2 + bias2 + return add1 + add2 + + +class TestKernelOptimization(TestCase): + def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3): + if len(set(ref_dict.keys())) != len(set(res_dict.keys())): + return False + for key1 in ref_dict.keys(): + key2 = "_orig_mod." + key1 + assert key2 in res_dict, f"{key1} does not exist in traced module" + if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol): + return False + return True + + def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3): + ref = module(*input) + res = traced(*input) + self.assertEqual(ref, res, rtol=rtol, atol=atol) + + def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3): + ref_params = dict(module.named_parameters()) + res_params = dict(traced.named_parameters()) + self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol)) + + def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3): + ref_grad = {key: param.grad for key, param in module.named_parameters()} + res_grad = {key: param.grad for key, param in traced.named_parameters()} + self.assertTrue( + self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol) + ) + + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "einsum_to_pointwise_pass": {}, + }, + post_grad_fusion_options={}, + ) + def test_einsum_to_pointwise(self): + counters.clear() + module = TestEinsumtoPointwise().to(GPU_TYPE) + input = [ + torch.randn(4096, 9, 512, device=GPU_TYPE, requires_grad=True), + torch.randn(9, 512, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(9, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 160, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 160, 96, device=GPU_TYPE, requires_grad=True), + torch.randn(4096, 9, 96, device=GPU_TYPE, requires_grad=True), + ] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + ref.sum().backward() + res.sum().backward() + self.compare_pred(module, traced, input) + self.compare_parameters(module, traced) + self.compare_gradients(module, traced) + self.assertEqual( + counters["inductor"]["einsum_to_pointwise_pass"], + 1, + ) + counters.clear() + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 9f40c7d3d23e..e451067be59a 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -50,7 +50,12 @@ aten = torch.ops.aten from torch._inductor.mock_cache import global_stats, PatchCaches, Stats from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import fresh_cache, run_and_get_code +from torch._inductor.utils import ( + fresh_cache, + get_k_splits, + run_and_get_code, + use_decompose_k_choice, +) from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck @@ -752,7 +757,12 @@ def test_cat_max_autotune_extern(self): @skipIfXpu( msg="The fusion not happened because it do not speedup on XPU, see issue #146568" ) - @config.patch(max_autotune_gemm_backends="TRITON") + @config.patch( + { + "max_autotune_gemm_backends": "TRITON", + "benchmark_epilogue_fusion": False, + } + ) def test_cat_max_autotune_triton(self): self._test_cat_max_autotune_impl(using_triton_mm=True) @@ -810,9 +820,9 @@ def test_non_contiguous_input_mm(self): Check https://github.com/pytorch/pytorch/issues/125437 for more details. """ x = rand_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) - y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x, y): @@ -825,9 +835,9 @@ def f(x, y): def test_non_contiguous_input_addmm(self): b = torch.randn((768), dtype=torch.bfloat16, device=GPU_TYPE) x = rand_strided( - (50257, 32768), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (50257, 2048), (1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) - y = rand_strided((32768, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided((2048, 768), (768, 1), dtype=torch.bfloat16, device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x, y): @@ -839,10 +849,10 @@ def f(x, y): def test_non_contiguous_input_bmm(self): x = rand_strided( - (1, 50257, 32768), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE + (1, 50257, 2048), (0, 1, 50304), dtype=torch.bfloat16, device=GPU_TYPE ) y = rand_strided( - (1, 32768, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE + (1, 2048, 768), (0, 768, 1), dtype=torch.bfloat16, device=GPU_TYPE ) @torch.compile(mode="max-autotune") @@ -856,16 +866,12 @@ def f(x, y): # TODO: fix accuracy failure of the triton template on XPU. # and enable this test case. @skipIfXpu - @unittest.skipIf( - os.getenv("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1", - "OOM when running with TORCHINDUCTOR_CPP_WRAPPER https://github.com/pytorch/pytorch/issues/126867", - ) def test_non_contiguous_input_mm_plus_mm(self): - x1 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) - y1 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) + x1 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) + y1 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) - x2 = rand_strided((50257, 32768), (1, 50304), device=GPU_TYPE) - y2 = rand_strided((32768, 768), (768, 1), device=GPU_TYPE) + x2 = rand_strided((50257, 2048), (1, 50304), device=GPU_TYPE) + y2 = rand_strided((2048, 768), (768, 1), device=GPU_TYPE) @torch.compile(mode="max-autotune") def f(x1, y1, x2, y2): @@ -1493,7 +1499,9 @@ def misses(): self.assertEqual(hits(), 4) self.assertEqual(misses(), 4) + @fresh_cache() @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "decompose_k not supported on ROCm") @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -1501,19 +1509,42 @@ def misses(): max_autotune=True, max_autotune_gemm_backends="TRITON", autotune_fallback_to_aten=False, - disable_decompose_k=True, ) - def test_max_autotune_disable_decompose_K(self): - M, N, K = (32, 32, 32768) - - a = torch.randn(M, K, dtype=torch.float16, device="cuda", requires_grad=True) - b = torch.randn(K, N, dtype=torch.float16, device="cuda", requires_grad=True) - - compiled_func = torch.compile(lambda a, b: a @ b) - out, code = run_and_get_code(compiled_func, a, b) - - for codegen in code: - FileCheck().check_not("decompose_k").run(codegen) + @parametrize("num_decompose_k_splits", (0, 5, 20)) + @parametrize("decompose_k_threshold", (8, 16)) + def test_max_autotune_decompose_k_envvars( + self, num_decompose_k_splits, decompose_k_threshold + ): + shapes = [(32, 32, 32768), (32, 32, 256)] + for M, N, K in shapes: + get_k_splits.cache_clear() + use_decompose_k_choice.cache_clear() + a = torch.randn(M, K, dtype=torch.float16, device="cuda") + b = torch.randn(K, N, dtype=torch.float16, device="cuda") + + with config.patch( + { + "triton.num_decompose_k_splits": num_decompose_k_splits, + "triton.decompose_k_threshold": decompose_k_threshold, + } + ): + compiled_func = torch.compile(lambda a, b: a @ b) + _, code = run_and_get_code(compiled_func, a, b) + + decompose_count = 0 + for codegen in code: + if "benchmark_decompose_k_mm" in codegen: + decompose_count += 1 + + if ( + K // M < decompose_k_threshold + or K // N < decompose_k_threshold + or num_decompose_k_splits == 0 + ): + self.assertEqual(decompose_count, 0) + else: + self.assertTrue(decompose_count > 0) + self.assertTrue(decompose_count <= num_decompose_k_splits) @skipIfXpu @unittest.skipIf( diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index eaff539f7a49..3e23442b38ec 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -8,6 +8,7 @@ from torch._inductor import config, memory from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_triton_code +from torch.testing._internal.common_utils import serialTest from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -306,6 +307,58 @@ def f(a, b, c): expected_bound = a.size(0) * c.size(1) * a.dtype.itemsize * 2 self.assertLess(peak_mem, expected_bound) + @serialTest() + def test_fusion_acc_large_reads(self): + def f(x, y, z): + res = torch.zeros_like(x[0]) + for i in range(4): + temp = torch.matmul(x, y) + z + res = res + temp + return res + + N = 128 + x = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + y = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + z = torch.rand(N, N, dtype=torch.float32, device=GPU_TYPE) + + # CASE 1: no restriction on the amount of accumulation + with config.patch({"realize_acc_reads_size_threshold": float("inf")}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf4, arg2_1, buf1, buf2, buf3") + .run(code) + ) + + # CASE 2: for tensors with the same size as x (which is 4 * N**2 bytes) + # at most 12 / 4 = 3 reads can be accumulated during fusion + with config.patch({"realize_acc_reads_size_threshold": 12 * N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf3, arg2_1, buf1, buf2,") + .check("triton_poi_fused_add_1.run(buf5, buf4, arg2_1,") + .run(code) + ) + + # CASE 3: no such fusion allowed + with config.patch({"realize_acc_reads_size_threshold": N**2}): + f_compiled = torch.compile(f) + code = run_and_get_triton_code(f_compiled, x, y, z) + ( + FileCheck() + .check("triton_poi_fused_add_0.run(buf1, arg2_1,") + .check("triton_poi_fused_add_0.run(buf3, arg2_1,") + .check("triton_poi_fused_add_0.run(buf4, buf3,") + .check("triton_poi_fused_add_0.run(buf6, arg2_1,") + .check("triton_poi_fused_add_0.run(buf7, buf6,") + .check("triton_poi_fused_add_0.run(buf9, arg2_1,") + .check("triton_poi_fused_add_0.run(buf10, buf9,") + .run(code) + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 7760bfd834ef..79ca002f7f5b 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -18,7 +18,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -312,6 +312,11 @@ def forward(self, x): memory_format, dtype, ) in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -350,7 +355,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_unary(self, device): self.device = device self._test_conv_unary_base(dim=4) @@ -358,7 +363,7 @@ def test_conv2d_unary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d_unary(self, device): self.device = device self._test_conv_unary_base(dim=5) @@ -442,7 +447,7 @@ def matcher_check_fn(): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose2d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=4) @@ -453,7 +458,7 @@ def test_conv_transpose2d_unary(self, device): @skipIfXpu( msg="The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." ) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose3d_unary(self, device): self.device = device self._test_conv_transpose_unary_base(dim=5) @@ -508,6 +513,11 @@ def forward(self, x): memory_format, dtype, ) in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() if dim == 4: x_shape = (1, 3, 56, 56) @@ -543,7 +553,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off(0.02) + @reduced_f32_on_and_off(0.02) def test_conv2d_binary(self, device): self.device = device self._test_conv_binary_base(dim=4) @@ -551,7 +561,7 @@ def test_conv2d_binary(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off(0.02) + @reduced_f32_on_and_off(0.02) def test_conv3d_binary(self, device): self.device = device self._test_conv_binary_base(dim=5) @@ -650,7 +660,7 @@ def matcher_check_fn(): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=4) @@ -658,7 +668,7 @@ def test_conv2d_binary_broadcast_shapes(self, device): @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d_binary_broadcast_shapes(self, device): self.device = device self._test_conv_binary_broadcast_shapes_base(dim=5) @@ -667,7 +677,7 @@ def test_conv3d_binary_broadcast_shapes(self, device): @skipIfNoONEDNN @skipIfRocm @unittest.skipIf(IS_FBCODE, "Failing in fbcode") - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d_linear_add_broadcast_shapes(self, device): self.device = device @@ -699,7 +709,7 @@ def matcher_check_fn(): class TestPatternMatcher(TestPatternMatcherBase): - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_linear_unary(self, device="cpu"): self.device = device @@ -730,10 +740,15 @@ def forward(self, x): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) - if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) options = itertools.product(unary_list, [True, False], dtypes) for unary_fn, bias, dtype in options: + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue metrics.reset() mod = M(unary_fn, 10, 30, bias=bias).eval() # only fuse for linear when the dtype is bf16 @@ -761,7 +776,7 @@ def matcher_check_fn(): expected_kernel_count -= 1 self.assertEqual(metrics.generated_kernel_count, expected_kernel_count) - @bf32_on_and_off() + @reduced_f32_on_and_off() @unittest.skipIf(not TEST_MKL, "Test requires MKL") def test_linear_fp32(self, device="cpu"): self.device = device @@ -909,7 +924,7 @@ def matcher_check_fn(): # 1 kernel for "to_lowp", 2 kernels for unary ops self.assertEqual(metrics.generated_kernel_count, 3) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_linear_binary(self, device="cpu"): self.device = device @@ -931,7 +946,7 @@ def forward(self, x, y): dtypes.append(torch.bfloat16) if is_mkldnn_fp16_supported(self.device): dtypes.append(torch.float16) - if torch.backends.mkldnn.matmul.fp32_precision == "bf16": + if torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"]: dtypes.append(torch.float32) options = itertools.product( binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes @@ -940,6 +955,11 @@ def forward(self, x, y): for binary_fn, input_shape, bias, dtype in options: metrics.reset() + if ( + dtype != torch.float32 + and torch.backends.mkldnn.matmul.fp32_precision == "tf32" + ): + continue def matcher_check_fn(): self.assertEqual( @@ -2952,6 +2972,104 @@ def test_qlinear_add_int8_mixed_bf16_xpu(self, use_relu, is_qat, is_dynamic): is_dynamic=is_dynamic, ) + def _test_qlinear_fp8_inductor_cpu_helper(self, qlinear_op, post_op="none"): + dtype = torch.float8_e4m3fn + qlinear_prepack = torch.ops.onednn.qlinear_prepack + post_op_algo = "none" + unary_post_op_args = () + batch_size = 1 + output_dtype = torch.float8_e4m3fn + y_scale, y_zp = 0.07, 0 + ic = 4 + oc = 16 + + torch._dynamo.reset() + used_y_scale = y_scale + used_y_zp = y_zp + x = torch.rand(batch_size, ic) + w = torch.rand(oc, ic) + qx = x.to(dtype) + qw = w.to(dtype) + x_scale = 0.5 + w_scales = torch.randn(oc) + b = torch.rand(oc) + + x_zp = 0 + w_zps = torch.zeros_like(w_scales, dtype=torch.int) + + if post_op == "none": + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.qw_packed = qlinear_prepack(qw, x.shape) + + def forward(self, qx): + qy = qlinear_op( + qx, + x_scale, + x_zp, + self.qw_packed, + w_scales, + w_zps, + b, + used_y_scale, + used_y_zp, + output_dtype, + post_op, + unary_post_op_args, + post_op_algo, + ) + return qy + + elif post_op == "add": + x2 = torch.rand(batch_size, oc) + binary_alpha = 1.0 # we only support alpha=1.0 now + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.qw_packed = qlinear_prepack(qw, x.shape) + + def forward(self, qx): + qy = qlinear_op( + qx, + x_scale, + x_zp, + self.qw_packed, + w_scales, + w_zps, + x2, + b, + used_y_scale, + used_y_zp, + output_dtype, + 1.0, + 0, + "add", + binary_alpha, + "none", + unary_post_op_args, + post_op_algo, + ) + return qy + + with torch.no_grad(): + model = Mod() + y_refe = model(qx) + y_test = torch.compile(model)(qx) + self.assertEqual(y_refe.float(), y_test.float()) + + @skipIfNoONEDNN + def test_qlinear_fp8_inductor_cpu(self): + qlinear_op = torch.ops.onednn.qlinear_pointwise.default + self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "none") + + @skipIfNoONEDNN + def test_qlinear_add_fp8_inductor_cpu(self): + qlinear_op = torch.ops.onednn.qlinear_pointwise.binary + self._test_qlinear_fp8_inductor_cpu_helper(qlinear_op, "add") + def _qlinear_dequant_promotion_test_helper( self, inputs, diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index 2966f3ac1d91..f576016cf08c 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -16,6 +16,7 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, + skipIfRocm, skipIfXpu, ) from torch.testing._internal.inductor_utils import ( @@ -98,6 +99,8 @@ def test_softmax(self, expect_multi_kernel=True): self.assertFalse(_contains_multi_kernel_code(wrapper_code)) @requires_triton() + # TODO: bobrenjc93 to fix multi-kernel for ROCM + @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_triton_gemm(self): def fn(x, y): @@ -123,6 +126,8 @@ def fn(x, y): self.assertTrue(_contains_multi_kernel_code(wrapper_code)) @requires_triton() + # TODO: bobrenjc93 to fix multi-kernel for ROCM + @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_triton_relu_fused_gemm(self): def fn(x, y): diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index c0efa7416ae1..2dd9ca44eb68 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -9,9 +9,15 @@ from pathlib import Path import torch +from torch._dynamo.utils import detect_fake_mode from torch._inductor import config -from torch._inductor.debug import create_node_mapping +from torch._inductor.debug import ( + create_mapping_pre_post_grad_nodes, + create_node_mapping_kernel_to_post_grad, +) +from torch._inductor.fx_passes.post_grad import post_grad_passes from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.virtualized import V from torch.testing._internal.inductor_utils import HAS_GPU from torch.testing._internal.triton_utils import requires_cuda @@ -56,6 +62,7 @@ def forward(self, a): @config.patch("trace.enabled", True) +@config.patch("trace.provenance_tracking", True) class TestProvenanceTracingArtifact(TestCase): """ This test checks that generated provenance tracing artifact from "post_grad" to @@ -121,6 +128,10 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): "mul_2", ], } + if backend == "aot_inductor": + expected_data["aoti_torch_cuda_mm_out"] = ["mm_default"] + else: + expected_data["extern_kernels.mm"] = ["mm_default"] self._check_provenance_tracing_artifact(filepath, expected_data) expected_mapping = [ ( @@ -171,6 +182,16 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): }, ), ] + if backend == "aot_inductor": + expected_mapping[0][1]["aoti_torch_cuda_mm_out"] = [ + "mm_default" + ] + expected_mapping[1][1]["mm_default"] = [ + "aoti_torch_cuda_mm_out" + ] + else: + expected_mapping[0][1]["extern_kernels.mm"] = ["mm_default"] + expected_mapping[1][1]["mm_default"] = ["extern_kernels.mm"] self._check_provenance_tracking_node_mappings( filepath, expected_mapping ) @@ -180,7 +201,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): if backend == "aot_inductor": expected_data = { "cpp_fused_mul_0": ["mul"], - "aoti_torch_cpu_addmm_out": ["addmm", "mul"], + "aoti_torch_cpu_addmm_out": ["addmm"], "cpp_fused_gelu_1": [ "mul_3", "mul_1", @@ -193,7 +214,6 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): # backend == "inductor" expected_data = { "cpp_fused_mul_0": ["mul"], - "aoti_torch_cpu_addmm_out": ["addmm", "mul"], "cpp_fused_gelu_1": [ "mul_3", "mul_1", @@ -201,7 +221,7 @@ def _test_triton_kernel_to_post_grad_tracing(self, device): "erf", "mul_2", ], - "extern_kernels.addmm": ["addmm", "mul"], + "extern_kernels.addmm": ["addmm"], } self._check_provenance_tracing_artifact(filepath, expected_data) @@ -252,14 +272,12 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self): filepath = Path(m.group(1)) if backend == "inductor": expected_data = { - "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], - "triton_poi_fused_0": ["_tensor_constant1"], "extern_kernels.addmm": ["addmm"], } else: # backend = aot_inductor expected_data = { - "aoti_torch_cuda_addmm_out": ["addmm", "_tensor_constant1"], + "aoti_torch_cuda_addmm_out": ["addmm"], "triton_poi_fused_0": ["_tensor_constant1"], } self._check_provenance_tracing_artifact(filepath, expected_data) @@ -374,11 +392,17 @@ def test_create_node_mapping(self): "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"] } - result = create_node_mapping( + result = create_mapping_pre_post_grad_nodes( pre_grad_graph_id, post_to_pre_grad_nodes_json, - triton_kernel_to_post_grad_json, ) + result = { + **result, + **create_node_mapping_kernel_to_post_grad( + triton_kernel_to_post_grad_json, + ), + } + self.assertEqual( result, { @@ -406,5 +430,58 @@ def test_create_node_mapping(self): ) +class TestProvenanceTracingNodeMeta(TestCase): + def get_node_with_target(self, gm, target): + """ + Return first node in gm with target + """ + return next(iter([node for node in gm.graph.nodes if node.target == target])) + + @requires_cuda # test only works for cuda pattern matcher + def test_pattern_matcher_transfer_meta(self): + """ + Test that stack trace is transfered when node is decomposed in post_grad_passes + """ + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + return x * 3 + + x = torch.randn(8, 10).to("cuda") + example_inputs = (x,) + model = Model().to("cuda") + + # mimic the before_post_grad graph + ep = torch.export.export(model, example_inputs).run_decompositions() + gm = ep.module() + + # Set fake mode for V + fake_inputs = [ + node.meta.get("val") for node in gm.graph.nodes if node.op == "placeholder" + ] + fake_mode = detect_fake_mode(fake_inputs) + V.set_fake_mode(fake_mode) + + addmm_node = self.get_node_with_target(gm, torch.ops.aten.addmm.default) + stack_trace = addmm_node.meta["stack_trace"] + + post_grad_passes(gm, True) # for this test is_inference doesn't matter + + mm_node = self.get_node_with_target(gm, torch.ops.aten.mm.default) + add_node = self.get_node_with_target(gm, torch.ops.aten.add.Tensor) + + self.assertEqual(add_node.meta["stack_trace"], stack_trace) + self.assertEqual(mm_node.meta["stack_trace"], stack_trace) + + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_remote_cache.py b/test/inductor/test_remote_cache.py new file mode 100644 index 000000000000..591713403bb8 --- /dev/null +++ b/test/inductor/test_remote_cache.py @@ -0,0 +1,76 @@ +# Owner(s): ["module: inductor"] +from dataclasses import dataclass + +from torch._inductor.remote_cache import ( + RemoteCache, + RemoteCacheBackend, + RemoteCachePassthroughSerde, +) +from torch.testing._internal.common_utils import TestCase + + +class FailingBackend(RemoteCacheBackend): + def _get(self, key): + raise AssertionError("testget") + + def _put(self, key, data): + raise AssertionError("testput") + + +class NoopBackend(RemoteCacheBackend): + def _get(self, key): + return None + + def _put(self, key, data): + return None + + +@dataclass +class TestSample: + fail: str = None + + +class FakeCache(RemoteCache): + def __init__(self): + super().__init__(FailingBackend(), RemoteCachePassthroughSerde()) + + def _create_sample(self): + return TestSample() + + def _log_sample(self, sample): + self.sample = sample + + +class TestRemoteCache(TestCase): + def test_normal_logging( + self, + ) -> None: + c = RemoteCache(NoopBackend(), RemoteCachePassthroughSerde()) + c.put("test", "value") + c.get("test") + + def test_failure_no_sample( + self, + ) -> None: + c = RemoteCache(FailingBackend(), RemoteCachePassthroughSerde()) + with self.assertRaises(AssertionError): + c.put("test", "value") + with self.assertRaises(AssertionError): + c.get("test") + + def test_failure_logging( + self, + ) -> None: + c = FakeCache() + with self.assertRaises(AssertionError): + c.put("test", "value") + self.assertEqual(c.sample.fail_reason, "testput") + with self.assertRaises(AssertionError): + c.get("test") + self.assertEqual(c.sample.fail_reason, "testget") + + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests() diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index 66781b4d7622..d2cd77fe5cd2 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -1,5 +1,8 @@ # Owner(s): ["module: inductor"] +import contextlib import functools +import unittest.mock +from typing import Callable from unittest.mock import patch import torch @@ -9,11 +12,25 @@ import torch.nn.functional as F from torch._dynamo.testing import expectedFailureDynamicWrapper from torch._dynamo.utils import counters +from torch._inductor import config from torch._inductor.autotune_process import TritonBenchmarkRequest +from torch._inductor.ir import FixedLayout +from torch._inductor.select_algorithm import ( + autotune_select_algorithm, + PartialRender, + TritonTemplate, + TritonTemplateKernel, +) from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import is_big_gpu +from torch._inductor.utils import is_big_gpu, run_and_get_kernels +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_GPU, + requires_gpu, + requires_triton, +) aten = torch.ops.aten @@ -402,6 +419,144 @@ def test_TritonTemplateCaller_str(self): self.assertEqual(caller_str, f"TritonTemplateCaller({module_path}, extra)") +@contextlib.contextmanager +def patch_lowering(lowering_overrides) -> Callable[[], None]: + import torch._inductor.lowering as inductor_lowering + + with unittest.mock.patch.dict(inductor_lowering.lowerings): + for fn, ( + decomp_fn, + broadcast, + type_promotion_kind, + convert_input_to_bool, + ) in lowering_overrides.items(): + inductor_lowering._register_lowering( + fn, + decomp_fn, + broadcast=broadcast, + type_promotion_kind=type_promotion_kind, + convert_input_to_bool=convert_input_to_bool, + lowering_dict=inductor_lowering.lowerings, + ) + + yield + + +class TestTemplateRender(TestCase): + @requires_gpu() + @requires_triton() + @config.patch(cuda_backend="triton") + def test_finalized_subclass_hooks(self): + """ + Tests that all registered triton template hooks have been finalized, + especially in the case that the hooks are finalized manually by the + caller i.e. by calling template.finalize_hook(hook_name) + """ + hook_identifier = "# CUSTOM_HOOK" + + class ExtensionTritonTemplateKernel(TritonTemplateKernel): + def custom_hook(self) -> str: + """ + Custom hook that just returns a test string for + validation + """ + + def hook() -> str: + return hook_identifier + + assert "" not in self.render_hooks + self.render_hooks[""] = hook + return "" + + def render( + self, template, kwargs, record_input_dependent_tracked_event=False + ): + if record_input_dependent_tracked_event: + self.cached_replay_events = [] + + template_env = { + fn.__name__: self.record_input_dependent_tracked_event()(fn) + if record_input_dependent_tracked_event + else fn + for fn in [ + self.def_kernel, + self.size, + self.stride, + self.store_output, + self.load_input, + self.make_load, + self.modification, + self.gen_argdefs, + self.gen_defines, + # This function registers a hook that the scheduler does + # not directly finalize + self.custom_hook, + ] + } + return PartialRender( + template.render(**template_env, **kwargs), + self.render_hooks, + ) + + class ExtensionTritonTemplate(TritonTemplate): + kernel_type = ExtensionTritonTemplateKernel + + add_template = ExtensionTritonTemplate( + name="add", + grid=lambda *args, **kwargs: (1, 1, 1), + source=( + r""" +{{def_kernel("A", "B")}} + {{custom_hook()}} + xoffset = tl.program_id(0) + xindex = xoffset + tl.arange(0, XBLOCK) + xmask = tl.full([XBLOCK], True, tl.int1) + tmp0 = tl.load(A + xindex) + tmp1 = tl.load(B + xindex) + tmp2 = tmp0 + tmp1 + {{store_output(("xindex",), "tmp2", mask="xmask")}} + """ + ), + ) + + XBLOCK = 32 + + def add_override(a, b, alpha=None): + layout = FixedLayout(a.get_device(), a.get_dtype(), a.get_size()) + choices = [] + add_template.maybe_append_choice( + choices, + input_nodes=(a, b), + layout=layout, + num_stages=1, + num_warps=2, + XBLOCK=XBLOCK, + ) + return autotune_select_algorithm("add", choices, [a, b], layout) + + with patch_lowering( + { + torch.ops.aten.add.Tensor: ( + add_override, + True, + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + False, + ) + } + ): + + @torch.compile + def add(a, b): + return a + b + + a = torch.zeros((XBLOCK,), device=GPU_TYPE) + b = torch.zeros((XBLOCK,), device=GPU_TYPE) + + _result, kernels = run_and_get_kernels(add, a, b) + assert len(kernels) == 1 + assert hook_identifier in kernels[0] + + if __name__ == "__main__": if IS_LINUX and HAS_GPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py index 9300c7d0d126..354552c497d9 100644 --- a/test/inductor/test_split_cat_fx_aten_passes.py +++ b/test/inductor/test_split_cat_fx_aten_passes.py @@ -49,6 +49,22 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): return torch.ops.aten.cat.default([cat_1, cat_2], 1) +class TestSplitCatSingular(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor): + cat = torch.ops.aten.cat.default([x, y], 1) + split = torch.ops.aten.split.Tensor(cat, 32, 1) + getitem = split[0] + cat_1 = torch.ops.aten.cat.default( + [getitem], + 1, + ) + cat_2 = torch.ops.aten.cat.default([getitem, z], 1) + return torch.ops.aten.cat.default([cat_1, cat_2], 1) + + class TestSplitCatPartial(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -275,6 +291,32 @@ def test_split_cat_post_grad(self): self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_cuda + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "normalization_aten_pass": {}, + "split_cat_aten_pass": {"threshold_to_cat": 5}, + }, + ) + def test_split_cat_post_grad_singular(self): + counters.clear() + inputs = [ + torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)), + torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)), + torch.randn(1024, 32, device=torch.device(device=GPU_TYPE)), + ] + module = TestSplitCatSingular() + traced = torch.compile(module) + ref = module(*inputs) + res = traced(*inputs) + self.compare_pred(module, traced, inputs) + self.assertEqual(counters["inductor"]["normalization_aten_pass"], 4) + self.assertEqual(counters["inductor"]["split_cat_aten_pass"], 0) + self.assertEqual(ref, res, rtol=1e-8, atol=1e-8) + self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) + counters.clear() + @requires_cuda @torch._inductor.config.patch( pre_grad_fusion_options={}, diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 477d5ac2e6c2..2ce294ed0ff5 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -14,7 +14,6 @@ from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda -from torch.torch_version import TorchVersion @requires_cuda @@ -140,36 +139,6 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - # TODO: floats don't work properly, triton seems to think they're all tl.float32 - # despite type annotations. - # There's also not really a good way for me to make a float16 in python... - @skipIfRocm - def test_floats(self): - @triton.jit - def floats(arg0, arg1: tl.float16, arg2: tl.float32, arg3: tl.float64): - x = tl.load(arg0) - y = arg1 + arg2 + arg3 - tl.store(arg0, x + y) - - arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") - - args = (arg0, 1.0, 1.0, 1.0) - - compiled_kernel = floats[1,](*args) - launcher = self._make_launcher(compiled_kernel) - if TorchVersion(triton.__version__) >= TorchVersion("3.4.0"): - self.assertEqual(launcher.arg_tys, "Offd") - else: - self.assertEqual(launcher.arg_tys, "Offf") - # TODO this line fails on Triton 3.4.0 (https://github.com/triton-lang/triton/issues/6176) - # Add the check back when this is fixed in Triton - # self.assertEqual(arg0, torch.tensor([3.0], dtype=torch.float64, device="cuda")) - new_arg0 = torch.zeros(1, dtype=torch.float64, device="cuda") - device_interface = get_interface_for_device("cuda") - stream = device_interface.get_raw_stream(device_interface.current_device()) - launcher.run(1, 1, 1, stream, new_arg0, 1.0, 1.0, 1.0) - self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_basic_1arg(self): @triton.jit diff --git a/test/inductor/test_static_linkage_utils.py b/test/inductor/test_static_linkage_utils.py deleted file mode 100644 index 0a728c1e66df..000000000000 --- a/test/inductor/test_static_linkage_utils.py +++ /dev/null @@ -1,157 +0,0 @@ -# Owner(s): ["module: inductor"] -from torch.testing._internal.common_utils import run_tests - - -def get_static_linkage_main_cpp_file(): - return """ -#include -#include -#include -#include -#include - -#include -#include -// Include the AOTInductor headers -#include "Minus.wrapper/data/aotinductor/model/Minus.h" -#include "Plus.wrapper/data/aotinductor/model/Plus.h" -#include -#include - -using torch::aot_inductor::AOTInductorModelMinus; -using torch::aot_inductor::AOTInductorModelPlus; -using torch::aot_inductor::ConstantHandle; -using torch::aot_inductor::ConstantMap; - - -int main(int argc, char* argv[]) { - if (argc < 2) { - std::cerr - << "Usage: ./main " - << std::endl; - return 1; - } - std::string path = argv[1]; - std::string device_str = argv[2]; - try { - torch::Device device(device_str); - - // Create two input tensors (10x10) - auto tensor1 = torch::ones({10, 10}, device); - auto tensor2 = torch::ones({10, 10}, device); - // Create two input tensors (10x10) - auto tensor3 = torch::ones({10, 10}, device); - auto tensor4 = torch::ones({10, 10}, device); - - std::vector input_tensors = {tensor1, tensor2}; - std::vector input_tensors2 = {tensor3, tensor4}; - - // Create array of input handles - auto input_handles1 = - torch::aot_inductor::unsafe_alloc_new_handles_from_tensors( - input_tensors); - auto input_handles2 = - torch::aot_inductor::unsafe_alloc_new_handles_from_tensors( - input_tensors2); - - // Create array for output handle - AtenTensorHandle output_handle1; - AtenTensorHandle output_handle2; - - auto constants_map = std::make_shared(); - auto constants_array = std::make_shared>(); - auto model1 = AOTInductorModelPlus::Create( - constants_map, constants_array, device_str, - path + "Plus.wrapper/data/" - "aotinductor/model/"); - model1->load_constants(); - - auto constants_map2 = std::make_shared(); - auto constants_array2 = std::make_shared>(); - auto model2 = AOTInductorModelMinus::Create( - constants_map2, constants_array2, device_str, - path + "Minus.wrapper/data/" - "aotinductor/model/"); - model2->load_constants(); - - // Run the model - torch::aot_inductor::DeviceStreamType stream1 = nullptr; - torch::aot_inductor::DeviceStreamType stream2 = nullptr; - model1->run(&input_handles1[0], &output_handle1, stream1, nullptr); - model2->run(&input_handles2[0], &output_handle2, stream2, nullptr); - - // Convert output handle to tensor - auto output_tensor1 = - torch::aot_inductor::alloc_tensors_by_stealing_from_handles( - &output_handle1, 1); - auto output_tensor2 = - torch::aot_inductor::alloc_tensors_by_stealing_from_handles( - &output_handle2, 1); - - if (!(torch::all(output_tensor1[0] == 2).item())){ - std::cout << "Wrong Output for Plus Model: " << output_tensor1 << std::endl; - throw std::runtime_error("Tensor does not contain only the expected value 2."); - } - if (!(torch::all(output_tensor2[0] == 0).item())){ - std::cout << "Wrong Output for Minus Model: " << output_tensor1 << std::endl; - throw std::runtime_error("Tensor does not contain only the expected value 0."); - } - - return 0; - } catch (const std::exception &e) { - std::cerr << "Error: " << e.what() << std::endl; - return 1; - } -} - -""" - - -def get_static_linkage_makelist_file_cuda(): - return """ -cmake_minimum_required(VERSION 3.10) -project(TestProject) - -set(CMAKE_CXX_STANDARD 17) - -find_package(Torch REQUIRED) -find_package(CUDA REQUIRED) - -add_subdirectory(Plus.wrapper/data/aotinductor/model/) -add_subdirectory(Minus.wrapper/data/aotinductor/model/) - -# Create executable -add_executable(main main.cpp) - -target_compile_definitions(main PRIVATE USE_CUDA) - -target_link_libraries(main PRIVATE torch cuda - ${CUDA_LIBRARIES} - Plus - Minus) -""" - - -def get_static_linkage_makelist_file_cpu(): - return """ -cmake_minimum_required(VERSION 3.10) -project(TestProject) - -set(CMAKE_CXX_STANDARD 17) - -find_package(Torch REQUIRED) - -add_subdirectory(Plus.wrapper/data/aotinductor/model/) -add_subdirectory(Minus.wrapper/data/aotinductor/model/) - -# Create executable -add_executable(main main.cpp) - -target_link_libraries(main PRIVATE torch - Plus - Minus) -""" - - -if __name__ == "__main__": - run_tests() diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 510e827705f7..2b8ace9db4c6 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -30,7 +30,9 @@ ) from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( + IS_CI, IS_MACOS, + IS_WINDOWS, IS_X86, skipCUDAMemoryLeakCheckIf, skipIfCrossRef, @@ -67,6 +69,15 @@ sys.exit(0) raise +if IS_WINDOWS and IS_CI: + # TODO(xuhancn) : improve the compiler build performance on windows. + sys.stderr.write( + "This UT is too slow on windows, and will cause out of time in CI. So skip it now.\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("skip slow test") + bf16 = torch.bfloat16 # not tested f64 = torch.float64 f32 = torch.float32 @@ -355,8 +366,9 @@ def wrapper_noop_set_seed(op, *args, **kwargs): return op(*args, **kwargs) -torch.testing._internal.common_methods_invocations.wrapper_set_seed = ( - wrapper_noop_set_seed +wrapper_noop_set_seed_decorator = patch( + "torch.testing._internal.common_methods_invocations.wrapper_set_seed", + wrapper_noop_set_seed, ) # key can be either op_name, or (op_name, dtype) @@ -969,6 +981,7 @@ def inner(self, device, dtype, op): return inner +@wrapper_noop_set_seed_decorator class TestInductorOpInfo(TestCase): def tearDown(self): torch._dynamo.reset() diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 8e9df9e03c84..cf132bea84a5 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -515,6 +515,37 @@ def fn(x): x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device) torch.compile(fn, fullgraph=True)(x) + @skipGPUIf(not HAS_GPU, "torch.compile for gpu requires triton") + @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) + def test_unbacked_linear_layer_norm_input(self, device): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(387, 128, bias=True, device=device) + self.layer_norm1 = torch.nn.LayerNorm(387, device=device) + self.layer_norm2 = torch.nn.LayerNorm(128, device=device) + + def forward(self, x, mask): + masked_select = x.masked_select(mask) + view = masked_select.view(-1, 387) + + linear = self.linear(view) + layer_norm1 = self.layer_norm1(view) + layer_norm2 = self.layer_norm2(linear) + return linear, layer_norm1, layer_norm2 + + model = MyModel() + inputs = ( + torch.randn((256, 387), dtype=torch.float, device=device), + torch.randint( + low=0, high=2, size=(256, 1), dtype=torch.bool, device=device + ), + ) + + actual = torch.compile(model, fullgraph=True)(*inputs) + expected = model(*inputs) + torch.testing.assert_close(actual, expected) + instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py deleted file mode 100644 index 2e47e48f140e..000000000000 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ /dev/null @@ -1,849 +0,0 @@ -# Owner(s): ["module: onnx"] -from __future__ import annotations - -import contextlib -import copy -import dataclasses -import os -import sys -import unittest -from pathlib import Path - -import onnxruntime -from parameterized import parameterized - -import torch -import torch._dynamo.backends.registry -from torch import nn -from torch.onnx import ( - _OrtBackend as OrtBackend, - _OrtBackendOptions as OrtBackendOptions, -) -from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import skipIfNNModuleInlined - - -sys.path.append(str(Path(__file__).absolute().parents[1])) - -import onnx_test_common - - -def make_aot_ort(): - ort_backend = OrtBackend(options=OrtBackendOptions()) - return ort_backend, ort_backend - - -class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime): - def setUp(self): - super().setUp() - torch._dynamo.reset() - OrtBackend.clear_cached_instances() - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - OrtBackend.clear_cached_instances() - - def test_get_ort_device_type(self): - from onnxruntime.capi import _pybind_state as ORTC - - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("cuda"), - ORTC.OrtDevice.cuda(), - ) - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("cpu"), - ORTC.OrtDevice.cpu(), - ) - self.assertEqual( - torch.onnx._internal.onnxruntime._get_ort_device_type("maia"), - ORTC.OrtDevice.npu(), - ) - - def test_torch_compile_backend_registration(self): - self.assertIn("onnxrt", torch._dynamo.backends.registry.list_backends()) - backend = torch._dynamo.backends.registry.lookup_backend("onnxrt") - self.assertEqual(backend.__module__, "torch.onnx._internal.onnxruntime") - - def _test_torch_compile_backend_caching_assert_reused( - self, options: OrtBackendOptions - ): - self.assertFalse(OrtBackend.get_cached_instances()) # assert setUp/tearDown - new_backend = OrtBackend.get_cached_instance_for_options(options) - reused_backend = OrtBackend.get_cached_instance_for_options(options) - self.assertEqual(len(OrtBackend.get_cached_instances()), 1) - self.assertIs(reused_backend, new_backend) - if options is None or options.ort_session_options is None: - # OrtBackendOptions.ort_session_options is a pybind11 object that - # cannot be pickled via dataclasses.asdict - self.assertEqual( - new_backend, - OrtBackend.get_cached_instance_for_options( - dataclasses.asdict(options) if options else None - ), - ) - - @parameterized.expand( - [ - (None,), - (OrtBackendOptions(),), - (OrtBackendOptions(use_aot_autograd=True),), - (OrtBackendOptions(use_aot_autograd=False),), - (OrtBackendOptions(preallocate_output=True),), - (OrtBackendOptions(preallocate_output=False),), - (OrtBackendOptions(infer_execution_providers=True),), - (OrtBackendOptions(infer_execution_providers=False),), - (OrtBackendOptions(preferred_execution_providers=["A", "B", "C"]),), - ( - OrtBackendOptions( - preferred_execution_providers=["A", "B", ("C", {"option": "value"})] - ), - ), - (OrtBackendOptions(default_execution_providers=["Something"]),), - (OrtBackendOptions(),), - ] - ) - def test_torch_compile_backend_caching_assert_reused( - self, options: OrtBackendOptions - ): - self._test_torch_compile_backend_caching_assert_reused(options) - - @parameterized.expand( - [ - (OrtBackendOptions(ort_session_options=onnxruntime.SessionOptions()),), - ] - ) - def test_torch_compile_backend_caching_assert_not_reused( - self, options: OrtBackendOptions - ): - with self.assertRaises(AssertionError): - self._test_torch_compile_backend_caching_assert_reused(options) - - def _test_model_numerically( - self, - model, - dynamo_backend, - example_args_collection, - fullgraph: bool = False, - test_backward: bool = False, - atol: float = 1e-5, - rtol: float = 1e-6, - ): - """Run original and compiled model and compare the results. - - Args: - model: The model to test. - dynamo_backend: The dynamo backend to use. Here we use string `onnxrt` or - the first returned value of `make_aot_ort()`. - example_args_collection: A tuple of example arguments to test. E.g., - ( - (torch.randn(2), torch.randn(2)), - (torch.randn(4), torch.randn(4)), - ) - if you want to test - model(torch.randn(2), torch.randn(2)) and - model(torch.randn(4), torch.randn(4)) - . - """ - compiled_model = torch.compile( - model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), - backend=dynamo_backend, - dynamic=True, - fullgraph=fullgraph, - ) - - for example_args in example_args_collection: - baseline_result = model(*example_args) - result = compiled_model(*example_args) - if isinstance(baseline_result, torch.Tensor): - torch.testing.assert_close( - baseline_result, result, atol=atol, rtol=rtol - ) - if test_backward: - baseline_result.sum().backward() - result.sum().backward() - for baseline_param, param in zip( - model.parameters(), compiled_model.parameters() - ): - torch.testing.assert_close( - baseline_param.grad, param.grad, atol=atol, rtol=rtol - ) - else: - assert test_backward is False, ( - "Calculating backward with multiple outputs is not supported yet." - ) - for baseline_elem, result_elem in zip(baseline_result, result): - torch.testing.assert_close( - baseline_elem, result_elem, atol=atol, rtol=rtol - ) - - def _assert_counting_information( - self, - ort_backend: OrtBackend, - # Number of session runs. - # If there is no graph break, this should be the same as - # total number of forward calls. - expected_execution_count: int, - # Number of GraphModule's cached. - # With one graph break, a model will be mapped - # to two GraphModule's. - number_of_cached_graph_modules: int, - # Number of ONNX models cached for each GraphModule, - # number_of_exported_onnx_models[i] contains # of ONNX models exported from - # the i-th element (type: torch.fx.GraphModule) in - # OrtBackend._all_ort_execution_info.execution_info_per_graph_module.values(). - number_of_exported_onnx_models_for_all_graph_modules: tuple[int, ...], - ): - self.assertEqual(expected_execution_count, ort_backend.execution_count) - self.assertEqual( - len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), - number_of_cached_graph_modules, - ) - self.assertEqual( - len(ort_backend._all_ort_execution_info.execution_info_per_graph_module), - len(number_of_exported_onnx_models_for_all_graph_modules), - ) - for ( - onnx_info, - expected_number_of_onnx_models, - ) in zip( - ort_backend._all_ort_execution_info.execution_info_per_graph_module.values(), - number_of_exported_onnx_models_for_all_graph_modules, - ): - self.assertEqual(len(onnx_info), expected_number_of_onnx_models) - - def _assert_dynamic_input_and_output_shapes_in_all_onnx_models(self, backend): - for ( - onnx_session_infos - ) in backend._all_ort_execution_info.execution_info_per_graph_module.values(): - for onnx_session_info in onnx_session_infos: - inputs_have_dynamic_shapes = False - for input in onnx_session_info.input_value_infos: - if hasattr(input.type, "tensor_type") and hasattr( - input.type.tensor_type, "shape" - ): - for dim in input.type.tensor_type.shape.dim: - inputs_have_dynamic_shapes = ( - inputs_have_dynamic_shapes or hasattr(dim, "dim_param") - ) - output_have_dynamic_shapes = False - for output in onnx_session_info.output_value_infos: - if hasattr(output.type, "tensor_type") and hasattr( - output.type.tensor_type, "shape" - ): - for dim in output.type.tensor_type.shape.dim: - output_have_dynamic_shapes = ( - output_have_dynamic_shapes or hasattr(dim, "dim_param") - ) - self.assertTrue(inputs_have_dynamic_shapes) - self.assertTrue(output_have_dynamic_shapes) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_elementwise_function_single_output(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 6, 8, 10) - ) - - def elementwise_model(x: torch.Tensor): - y = x.relu() - z = y.sigmoid() - return z - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - # This will use the global ONNXRuntime backend registered - # in Dynamo to compile the tested model. - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - elementwise_model, - local_aot_ort, - example_args_collection, - ) - - # We can only check local backend's counting information - # since global backend's counting information comes from - # all compiled models. - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - # OrtBackend._ort_acclerated_call should have been called 5 times because - # we have 5 different batch sizes to test. - expected_execution_count=len(example_args_collection), - # Since this local_ort only compiled one function, - # there should be only one GraphModule in its cached. - number_of_cached_graph_modules=1, - # Since dynamic shape is enabled, we should only have one ONNX model - # to support different batch sizes. - number_of_exported_onnx_models_for_all_graph_modules=(1,), - ) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_elementwise_function_multiple_output(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, dtype=torch.float32),) for batch in (2, 4, 8) - ) - - def elementwise_model_with_multiple_outputs(w: torch.Tensor): - x = w + w - y = x.relu() - z = y * y - return x, y, z - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - elementwise_model_with_multiple_outputs, - local_aot_ort, - example_args_collection, - ) - - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - expected_execution_count=len(example_args_collection), - number_of_cached_graph_modules=1, - number_of_exported_onnx_models_for_all_graph_modules=(1,), - ) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_mlp_with_local_backend(self, test_local_backend: bool): - example_args_collection = tuple( - (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) - ) - - class MLP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = nn.Linear(2, 4, bias=True) - self.fc2 = nn.Linear(4, 2, bias=True) - - def forward(self, tensor_x: torch.Tensor): - tensor_x = self.fc1(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - tensor_x = self.fc2(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - return tensor_x - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - self._test_model_numerically( - MLP(), - local_aot_ort, - example_args_collection, - ) - - if test_local_backend: - assert local_ort is not None - self._assert_counting_information( - local_ort, - # OrtBackend._ort_acclerated_call should have been called 5 times because - # we have 5 different batch sizes to test. - expected_execution_count=len(example_args_collection), - # Since this local_ort only compiled one function, there should be only two - # GraphModule's in its cached. One for batch sizes 2, 4, 6, 8 and the other - # for batch size 1. - number_of_cached_graph_modules=2, - # Since dynamic shape is enabled, we should only have one ONNX model - # to support different batch sizes. - number_of_exported_onnx_models_for_all_graph_modules=(1, 1), - ) - - @parameterized.expand( - [ - (True, True), - (True, False), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_attention_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import ( # noqa: F811 - LlamaAttention, - ) - - hidden_size = 16 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=hidden_size, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - class LlamaAttentionWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - try: - # New version of LlamaAttention has layer_idx argument. - self.attention = LlamaAttention(config, layer_idx=0) - except TypeError: - # Fall back to old version of LlamaAttention. - self.attention = LlamaAttention(config) - - def forward(self, hidden_states, attention_mask, position_ids): - attn_output, _, _ = self.attention( - hidden_states, attention_mask, position_ids - ) - return attn_output - - def generate_example_inputs(batch: int, seq: int, hidden_size: int): - # shape: batch x seq x hidden_size - hidden_state = torch.randn(batch, seq, hidden_size) - # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] - # shape: batch x 1 x seq x seq - attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - - return hidden_state, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8, hidden_size), - generate_example_inputs(4, 7, hidden_size), - generate_example_inputs(9, 15, hidden_size), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaAttentionWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - - execution_count = len(example_args_collection) * number_of_captured_graphs - self._assert_counting_information( - local_ort, - # Number of InferenceSession runs. - expected_execution_count=execution_count, - # Number of GraphModule's seen by ORT. - number_of_cached_graph_modules=number_of_captured_graphs, - # Number of InferenceSession's created per GraphModule. - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True, False), - (True, True), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_decoder_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import ( # noqa: F811 - LlamaDecoderLayer, - ) - - hidden_size = 16 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=hidden_size, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - class LlamaDecoderWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - try: - # New version of LlamaDecoderLayer has layer_idx argument. - self.decoder = LlamaDecoderLayer(config, layer_idx=0) - except TypeError: - # Fall back to old version of LlamaDecoderLayer. - self.decoder = LlamaDecoderLayer(config) - - def forward(self, hidden_states, attention_mask, position_ids): - (decoder_output,) = self.decoder( - hidden_states, attention_mask, position_ids - ) - return decoder_output - - def generate_example_inputs(batch: int, seq: int, hidden_size: int): - # shape: batch x seq x hidden_size - hidden_state = torch.randn(batch, seq, hidden_size) - # [0.0000e+00, ..., 0.0000e+00, -3.4028e+38, ...] - # shape: batch x 1 x seq x seq - attention_mask = torch.zeros(batch, 1, seq, seq, dtype=torch.float) - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - return hidden_state, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8, hidden_size), - generate_example_inputs(4, 7, hidden_size), - generate_example_inputs(9, 15, hidden_size), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaDecoderWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - - execution_count = len(example_args_collection) * number_of_captured_graphs - - self._assert_counting_information( - local_ort, - expected_execution_count=execution_count, - number_of_cached_graph_modules=number_of_captured_graphs, - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True, False), - (True, True), - ] - ) - @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/129456") - def test_llama_with_local_backend( - self, test_local_backend: bool, test_backward: bool - ): - from transformers import LlamaConfig # noqa: F811 - from transformers.models.llama.modeling_llama import LlamaModel # noqa: F811 - - config = LlamaConfig( - num_hidden_layers=1, - vocab_size=1024, - hidden_size=16, - intermediate_size=16, - max_position_embeddings=256, - num_attention_heads=2, - hidden_dropout_prob=0.0, - attention_dropout_prob=0.0, - ) - - config._attn_implementation = "eager" - - class LlamaModelWrapper(torch.nn.Module): - def __init__(self, config): - super().__init__() - self.llama = LlamaModel(config) - - def forward(self, input_ids, attention_mask, position_ids): - decoder_output = self.llama( - input_ids, attention_mask, position_ids, return_dict=False - ) - return decoder_output[0] - - def generate_example_inputs(batch: int, seq: int): - # shape: batch x seq x hidden_size - input_ids = torch.randint(0, 7, size=(batch, seq), dtype=torch.int64) - # Usually, its shape is a tensor with shape batch x seq x seq. - # However, to bypass some control flow in the model, we use None. - attention_mask = None - position_ids = torch.arange(0, seq, dtype=torch.int64) - position_ids = position_ids.unsqueeze(0).view(-1, seq) - return input_ids, attention_mask, position_ids - - # Reason for using multiple example argument groups: - # Export model to ONNX with one example argument group - # and test it with other example argument groups. - example_args_collection = ( - generate_example_inputs(2, 8), - generate_example_inputs(4, 7), - generate_example_inputs(9, 15), - ) - - if test_local_backend: - local_aot_ort, local_ort = make_aot_ort() - else: - local_aot_ort, local_ort = "onnxrt", None - - model = LlamaModelWrapper(config).eval() - - self._test_model_numerically( - model, - local_aot_ort, - example_args_collection, - fullgraph=True, - test_backward=test_backward, - atol=1e-4, - rtol=1e-4, - ) - - if test_local_backend: - assert local_ort is not None - number_of_captured_graphs = 2 if test_backward else 1 - execution_count = len(example_args_collection) * number_of_captured_graphs - self._assert_counting_information( - local_ort, - expected_execution_count=execution_count, - number_of_cached_graph_modules=number_of_captured_graphs, - number_of_exported_onnx_models_for_all_graph_modules=(1,) - * number_of_captured_graphs, - ) - self._assert_dynamic_input_and_output_shapes_in_all_onnx_models(local_ort) - - @parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_dump_model(self, test_local_backend: bool): - @contextlib.contextmanager - def onnxrt_dump_path(path): - key = "ONNXRT_DUMP_PATH" - before = os.environ.get(key, None) - os.environ[key] = path - yield - if before is None: - del os.environ[key] - else: - os.environ[key] = before - - example_args_collection = tuple( - (torch.randn(batch, 2, dtype=torch.float32),) for batch in (1, 2, 4, 6, 8) - ) - - class MLP(nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = nn.Linear(2, 4, bias=True) - self.fc2 = nn.Linear(4, 2, bias=True) - - def forward(self, tensor_x: torch.Tensor): - tensor_x = self.fc1(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - tensor_x = self.fc2(tensor_x) - tensor_x = torch.sigmoid(tensor_x) - return tensor_x - - if test_local_backend: - local_aot_ort, _ = make_aot_ort() - else: - local_aot_ort, _ = "onnxrt", None - - prefix = f"test_dump_model_{'local' if test_local_backend else 'onnxrt'}_" - expected = f"{prefix}0.onnx" - expected_graph = f"{prefix}0.txt" - if os.path.exists(expected): - os.remove(expected) - if os.path.exists(expected_graph): - os.remove(expected_graph) - not_expected = f"{prefix}1.onnx" - self.assertFalse(os.path.exists(not_expected)) - - model = MLP() - compiled_model = torch.compile( - model if not isinstance(model, torch.nn.Module) else copy.deepcopy(model), - backend=local_aot_ort, - dynamic=True, - ) - - self.assertFalse(os.path.exists(expected)) - self.assertFalse(os.path.exists(not_expected)) - - with onnxrt_dump_path(prefix): - example_args = example_args_collection[0] - compiled_model(*example_args) - self.assertTrue(os.path.exists(expected)) - self.assertTrue(os.path.exists(expected_graph)) - self.assertFalse(os.path.exists(not_expected)) - - compiled_model(*example_args) - self.assertTrue(os.path.exists(expected)) - self.assertFalse(os.path.exists(not_expected)) - - @unittest.skipIf(not torch.cuda.is_available(), "No CUDA to run mix devicei nputs") - def test_mix_device_inputs(self): - data = torch.randn(4, 8, device="cuda") - ref_data = torch.randn(8, 4, device="cpu") - - def reshape_wrapper(data, ref_cpu_data): - # Dummy line to make sure ref_cpu_data - # is included in the captured graph. - ref_cpu_data += 1 - shape = ref_cpu_data.shape - # A call with GPU and CPU inputs. - return torch.reshape(data, shape) - - compiled_model = torch.compile( - reshape_wrapper, - backend="onnxrt", - dynamic=True, - ) - - result = compiled_model(data, ref_data) - - self.assertTrue(torch.allclose(result, data.view(ref_data.shape))) - - def test_no_input(self): - def reshape_wrapper(): - # A model without input. - ones = torch.ones(4, 8) - zeros = torch.zeros(4, 8) - return ones + zeros - - recorded_models = [] - - def record_onnx_model_transform(onnx_model): - # Record the ONNX model seen by the transform. - recorded_models.append(onnx_model) - - compiled_model = torch.compile( - reshape_wrapper, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[ - record_onnx_model_transform, - ] - ), - ) - - result = compiled_model() - - self.assertEqual(len(recorded_models), 1) - # NOTE: Constant folded by optimizer - self.assertTrue( - "Constant" in [node.op_type for node in recorded_models[0].graph.node] - ) - - self.assertEqual(result, torch.ones(4, 8)) - - def test_custom_onnx_transform(self): - # This test consists of 2 parts: - # 1. If a registered ONNX transform is called and recorded a model. - # 2. If a registered ONNX transform is called and changed the model - - # Part 1: Record the ONNX model seen by the transform. - # This list contains the models recorded by record_onnx_model_transform. - recorded_models = [] - - def record_onnx_model_transform(onnx_model): - # Record the ONNX model seen by the transform. - recorded_models.append(onnx_model) - - def example_model(x: torch.Tensor): - y = torch.sigmoid(x) - z = x + y - return z - - compiled_model = torch.compile( - example_model, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[record_onnx_model_transform] - ), - ) - - x = torch.randn(2) - assert len(recorded_models) == 0 - y = compiled_model(x) - assert len(recorded_models) == 1 - - # Part 2: Change the ONNX model seen by the transform so that - # ORT receives a different model. - # NOTE: the function is optimized away by optimizer - def replace_relu_with_sigmoid(onnx_model): - for node in onnx_model.graph.node: - if node.op_type == "Relu": - node.op_type = "Sigmoid" - - def another_example_model(x: torch.Tensor): - y = torch.relu(x) - z = x + y - return z - - another_compiled = torch.compile( - another_example_model, - backend="onnxrt", - dynamic=True, - options=torch.onnx._OrtBackendOptions( - pre_ort_model_transforms=[ - replace_relu_with_sigmoid, - record_onnx_model_transform, - ] - ), - ) - - another_y = another_compiled(x) - # We have 2 models recorded `record_onnx_model_transform` - # by the 2 torch.compile calls above. - assert len(recorded_models) == 2 - # Since we have changed "Relu" to "Sigmoid" in replace_sigmoid_with_relu, - # the result should be the same to previous y. - torch.testing.assert_close(y, another_y) - # another_example_model still uses "Relu", so the result should be different - # than y. - self.assertFalse(torch.allclose(y, another_example_model(x))) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 993da8305868..9a8a171b5fe2 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -281,10 +281,11 @@ def forward(self, input): # Use GELU activation function return torch.nn.functional.gelu(input, approximate="tanh") - input = torch.randn(1, 3, 4, 4) + input = (torch.randn(1, 3, 4, 4),) onnx_program_op18 = torch.onnx.export( GeluModel(), input, + opset_version=18, dynamo=True, ) all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index d90d77b54786..4a1cb8f45814 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -3550,14 +3550,15 @@ def test_wrapped_fbgemm_linear_fp16(self): (2, 4), # batch_size (4, 5), # input_channels (4, 7), # output_channels + (True, False), # bias None or not ) - for batch_size, input_channels, output_channels in options: + for batch_size, input_channels, output_channels, bias_is_none in options: pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16 linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight x = torch.randn(batch_size, input_channels) w = torch.randn(output_channels, input_channels) - bias = torch.randn(output_channels) + bias = torch.randn(output_channels) if not bias_is_none else None w_packed = pack_op(w) out = linear_op(x, w_packed, bias, output_channels) @@ -3591,6 +3592,18 @@ def func(X, W, B): self.assertEqual(ref_out, compiled_out) + def func(X, W): + packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W) + return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, None, W.size(0)) + + ref_out = func(x, w) + + compiled = torch.compile(func) + compiled_out = compiled(x, w) + + self.assertEqual(ref_out, compiled_out) + + """Tests the correctness of the dynamic quantized lstm/gru.""" def _get_rnn_inputs(self, seq_len, num_batches, input_size, hidden_size, num_directions, reduce_range): diff --git a/test/run_test.py b/test/run_test.py index 7f810c039c7f..c63fb64e8f05 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -28,6 +28,7 @@ from torch.testing._internal.common_utils import ( get_report_path, IS_CI, + IS_LINUX, IS_MACOS, retry_shell, set_cwd, @@ -35,7 +36,6 @@ TEST_CUDA, TEST_SAVE_XML, TEST_WITH_ASAN, - TEST_WITH_CROSSREF, TEST_WITH_ROCM, TEST_WITH_SLOW_GRADCHECK, ) @@ -913,8 +913,12 @@ def _test_autoload(test_directory, options, enable=True): def run_test_with_openreg(test_module, test_directory, options): + # TODO(FFFrog): Will remove this later when windows/macos are supported. + if not IS_LINUX: + return 0 + openreg_dir = os.path.join( - test_directory, "cpp_extensions", "open_registration_extension" + test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg" ) install_dir, return_code = install_cpp_extensions(openreg_dir) if return_code != 0: @@ -1405,11 +1409,6 @@ def parse_args(): action="store_true", help="Enables removing tests based on TD", default=IS_CI - and ( - TEST_WITH_CROSSREF - or TEST_CONFIG == "distributed" - or TEST_CONFIG == "default" - ) and get_pr_number() is not None and not strtobool(os.environ.get("NO_TD", "False")) and not IS_MACOS @@ -1587,6 +1586,7 @@ def get_selected_tests(options) -> list[str]: "test_nn", "inductor/test_mps_basic", "inductor/test_torchinductor", + "inductor/test_aot_inductor", ] else: # Exclude all mps tests otherwise diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index d671e3f874c9..06b681bee981 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -322,12 +322,15 @@ def test_jit_cuda_archflags(self): [f"{capability[0]}{capability[1]}" for capability in capabilities], None, ), - "Maxwell+Tegra;6.1": (["53", "61"], None), - "Volta": (["70"], ["70"]), } archflags["7.5+PTX"] = (["75"], ["75"]) - archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) - if int(torch.version.cuda.split(".")[0]) < 12: + major, minor = map(int, torch.version.cuda.split(".")[:2]) + if major < 12 or (major == 12 and minor <= 9): + # Compute capability <= 7.0 is only supported up to CUDA 12.9 + archflags["Maxwell+Tegra;6.1"] = (["53", "61"], None) + archflags["Volta"] = (["70"], ["70"]) + archflags["5.0;6.0+PTX;7.0;7.5"] = (["50", "60", "70", "75"], ["60"]) + if major < 12: # CUDA 12 drops compute capability < 5.0 archflags["Pascal 3.5"] = (["35", "60", "61"], None) diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index dc44f66bcebc..d6d7ad3dc467 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -3,7 +3,7 @@ import os import unittest -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch import torch.testing._internal.common_utils as common diff --git a/test/test_cuda.py b/test/test_cuda.py index aec308101461..e4b5cf51b6f7 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -6484,12 +6484,10 @@ def test_cuda_autocast_deprecated_warning(self): with torch.cuda.amp.autocast(): _ = torch.ones(10) - def test_cuda_module_loading_env(self): - torch.cuda.init() - val = os.environ.get("CUDA_MODULE_LOADING", "") - self.assertEqual(val, "LAZY") - +@unittest.skipIf( + os.environ.get("USE_LEGACY_DRIVER", None) == "1", "Doesn't work with older driver" +) class TestCompileKernel(TestCase): @unittest.skipIf(TEST_WITH_ROCM, "ROCM does not support nvrtc") @unittest.skipIf(not TEST_CUDA, "No CUDA") diff --git a/test/test_decomp.py b/test/test_decomp.py index 5d641e32e422..dcd6e69af997 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -15,7 +15,7 @@ from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor -from torch.testing._internal.common_cuda import tf32_off +from torch.testing._internal.common_cuda import SM70OrLater, tf32_off from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, @@ -1226,6 +1226,33 @@ def f(x, w, b): for o_ref, o in zip(out_ref, out): self.assertEqual(o_ref.dtype, o.dtype) + @onlyCUDA + @unittest.skipIf(not SM70OrLater, "triton") + def test_rms_norm_decomp_cuda(self, device): + @torch.compile + def rms_norm_sinh(a, b, c): + output = torch.nn.functional.rms_norm(a, b, c) + return torch.sinh(output) + + normalized_shape_arg = (3, 3, 3) + input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True) + + def forward_pass_fn(): + return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor) + + model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code( + forward_pass_fn + ) + + # check RMSNorm was fused with sinh + self.assertTrue( + "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] + ) + self.assertTrue( + "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + ) + instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index 389f63efa687..f3272cc69476 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -3,10 +3,13 @@ import torch from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( + deviceCountAtLeast, dtypes, instantiate_device_type_tests, + onlyCPU, onlyCUDA, onlyNativeDeviceTypes, + skipCUDAIfNotRocm, skipCUDAIfRocm, skipMeta, ) @@ -17,7 +20,7 @@ skipIfTorchDynamo, TestCase, ) -from torch.utils.dlpack import from_dlpack, to_dlpack +from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack # Wraps a tensor, exposing only DLPack methods: @@ -241,25 +244,81 @@ def test_dlpack_tensor_invalid_stream(self, device, dtype): x = make_tensor((5,), dtype=dtype, device=device) x.__dlpack__(stream=object()) + @skipMeta + @onlyCUDA + @skipCUDAIfRocm + def test_dlpack_cuda_per_thread_stream(self, device): + # Test whether we raise an error if we are trying to use per-thread default + # stream, which is currently not supported by PyTorch. + x = make_tensor((5,), dtype=torch.float32, device=device) + with self.assertRaisesRegex( + BufferError, "per-thread default stream is not supported" + ): + x.__dlpack__(stream=2) + + @skipMeta + @onlyCUDA + @skipCUDAIfNotRocm + def test_dlpack_invalid_rocm_streams(self, device): + # Test that we correctly raise errors on unsupported ROCm streams. + def test(x, stream): + with self.assertRaisesRegex( + AssertionError, r"unsupported stream on ROCm: \d" + ): + x.__dlpack__(stream=stream) + + x = make_tensor((5,), dtype=torch.float32, device=device) + test(x, stream=1) + test(x, stream=2) + + @skipMeta + @onlyCUDA + @skipCUDAIfRocm + def test_dlpack_invalid_cuda_streams(self, device): + x = make_tensor((5,), dtype=torch.float32, device=device) + with self.assertRaisesRegex(AssertionError, r"unsupported stream on CUDA: \d"): + x.__dlpack__(stream=0) + + @skipMeta + def test_dlpack_invalid_cpu_stream(self): + x = make_tensor((5,), dtype=torch.float32, device="cpu") + with self.assertRaisesRegex(AssertionError, r"stream should be None on cpu."): + x.__dlpack__(stream=0) + + @skipMeta + @onlyCUDA + @deviceCountAtLeast(2) + def test_dlpack_tensor_on_different_device(self, devices): + dev0, dev1 = devices[:2] + + with torch.device(dev0): + x = make_tensor((5,), dtype=torch.float32, device=dev0) + + with self.assertRaisesRegex( + BufferError, r"Can't export tensors on a different CUDA device" + ): + with torch.device(dev1): + x.__dlpack__() + # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required @skipMeta def test_dlpack_export_requires_grad(self): x = torch.zeros(10, dtype=torch.float32, requires_grad=True) - with self.assertRaisesRegex(RuntimeError, r"require gradient"): + with self.assertRaisesRegex(BufferError, r"require gradient"): x.__dlpack__() @skipMeta def test_dlpack_export_is_conj(self): x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) y = torch.conj(x) - with self.assertRaisesRegex(RuntimeError, r"conjugate bit"): + with self.assertRaisesRegex(BufferError, r"conjugate bit"): y.__dlpack__() @skipMeta def test_dlpack_export_non_strided(self): x = torch.sparse_coo_tensor([[0]], [1], size=(1,)) y = torch.conj(x) - with self.assertRaisesRegex(RuntimeError, r"strided"): + with self.assertRaisesRegex(BufferError, r"strided"): y.__dlpack__() @skipMeta @@ -317,6 +376,112 @@ def test(device, **kwargs): # Consumer should still be able to process a smaller version capsule. test(device, max_version=(2, 0)) + @skipMeta + @onlyCPU + @dtypes( + # Note: NumPy DLPack bool support only landed in 1.25. + *all_types_and_complex_and( + torch.half, + torch.uint16, + torch.uint32, + torch.uint64, + ) + ) + def test_numpy_dlpack_protocol_conversion(self, device, dtype): + import numpy as np + + t = make_tensor((5,), dtype=dtype, device=device) + + if hasattr(np, "from_dlpack"): + # DLPack support only available from NumPy 1.22 onwards. + # Here, we test having another framework (NumPy) calling our + # Tensor.__dlpack__ implementation. + arr = np.from_dlpack(t) + self.assertEqual(t, arr) + + # We can't use the array created above as input to from_dlpack. + # That's because DLPack imported NumPy arrays are read-only. + # Thus, we need to convert it to NumPy by using the numpy() method. + t_arr = t.numpy() + + # Transform the NumPy array back using DLPack. + res = from_dlpack(t_arr) + + self.assertEqual(t, res) + self.assertEqual(t.data_ptr(), res.data_ptr()) + + def _test_from_dlpack(self, device, out_device=None, copy=None): + if isinstance(device, str): + device = torch.device(device) + + inp = make_tensor((5,), dtype=torch.float32, device=device) + out = torch.from_dlpack(inp, device=out_device, copy=copy) + + if out_device is None: + out_device = device + if isinstance(out_device, str): + out_device = torch.device(out_device) + + self.assertEqual(inp, out) + self.assertEqual(out.device, out_device) + + # They should be moved (i.e. not copied) only if: + # (a) we are forcing move, i.e. copy=False + # (b) the output device is the same as the input one AND copy is None + if copy is False or (copy is None and device == out_device): + self.assertEqual(inp.data_ptr(), out.data_ptr()) + else: + # Otherwise, inp should be copied. + self.assertNotEqual(inp.data_ptr(), out.data_ptr()) + + @skipMeta + @onlyCUDA + def test_copy(self, device): + # Force-copy same device tensor. + self._test_from_dlpack(device, copy=True) + self._test_from_dlpack(device, out_device=device, copy=True) + # Output should be in a different device, i.e. should have been copied. + self._test_from_dlpack(device, out_device="cpu") + self._test_from_dlpack(device, out_device="cpu", copy=True) + + @skipMeta + @onlyCUDA + def test_no_copy(self, device): + # No copy, since tensor lives in the same device. + self._test_from_dlpack(device) + self._test_from_dlpack(device, copy=False) + self._test_from_dlpack(device, out_device=device) + self._test_from_dlpack(device, out_device=device, copy=False) + + @skipMeta + @onlyCUDA + def test_needs_copy_error(self, device): + with self.assertRaisesRegex(ValueError, r"cannot move .* tensor from .*"): + self._test_from_dlpack(device, out_device="cpu", copy=False) + + @skipMeta + @onlyNativeDeviceTypes + def test_unsupported_device_error(self, device): + inp = make_tensor((5,), dtype=torch.float32, device=device) + dl_device_type = DLDeviceType.kDLHexagon + + with self.assertRaisesRegex( + BufferError, f"Unsupported device_type: {int(dl_device_type)}" + ): + inp.__dlpack__(max_version=(1, 0), dl_device=(dl_device_type, 0)) + + @skipMeta + @onlyCPU + def test_dlpack_unsupported_dtype_error(self, device): + inp = make_tensor((5,), dtype=torch.float32, device=device).to( + torch.float8_e4m3fn + ) + + with self.assertRaisesRegex( + BufferError, ".* types are not supported by dlpack" + ): + from_dlpack(inp) + instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0f299cd6b6c7..af16a8f325fc 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -3211,7 +3211,7 @@ def make_non_contiguous_tensor_and_test(cnt): self.assertEqual(compiled_result, eager_result) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): make_non_contiguous_tensor_and_test(4) @@ -3246,7 +3246,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", torch._dynamo.decorators.mark_unbacked(x, 0) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): compiled_result = compiled_func(x, torch.tensor([10])) @@ -3305,7 +3305,7 @@ def func(x, y): torch._dynamo.decorators.mark_unbacked(x, 1) log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): result_eager = func(x, torch.tensor([5, 20])) @@ -3355,7 +3355,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", # Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride. log_stream, ctx = logs_to_string( - "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs" + "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. diff --git a/test/test_indexing.py b/test/test_indexing.py index fa7de92b9829..3870734f60d3 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -15,6 +15,7 @@ dtypes, dtypesIfCPU, dtypesIfCUDA, + dtypesIfMPS, instantiate_device_type_tests, onlyCUDA, onlyNativeDeviceTypes, @@ -140,7 +141,10 @@ def consec(size, start=1): ) lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] - tensor = torch.DoubleTensor(lst).to(device) + _make_tensor = ( + torch.DoubleTensor if not device.startswith("mps") else torch.FloatTensor + ) + tensor = _make_tensor(lst).to(device) for _i in range(100): idx1_start = random.randrange(10) idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) @@ -156,7 +160,7 @@ def consec(size, start=1): else: lst_indexed = lst[idx1] tensor_indexed = tensor[idx1] - self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) + self.assertEqual(_make_tensor(lst_indexed), tensor_indexed) self.assertRaises(ValueError, lambda: reference[1:9:0]) self.assertRaises(ValueError, lambda: reference[1:9:-1]) @@ -908,6 +912,66 @@ def test_multiple_bool_indices(self, device): mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) self.assertEqual(v[mask1, :, mask2].shape, (3, 7)) + def test_multi_dimensional_bool_mask(self, device): + x = torch.randn(2, 2, 3, device=device) + b = ((True, False), (False, False)) + m = torch.tensor(b, dtype=torch.bool, device=device) + z = torch.tensor(0) + t = torch.tensor(True) + f = torch.tensor(False) + + # Using boolean sequence + self.assertEqual(x[b,].shape, (1, 3)) + self.assertEqual(x[b, ::2].shape, (1, 2)) + self.assertEqual(x[b, None].shape, (1, 1, 3)) + self.assertEqual(x[b, 0].shape, (1,)) + self.assertEqual(x[b, z].shape, (1,)) + self.assertEqual(x[b, True].shape, (1, 3)) + self.assertEqual(x[b, True, True, True, True].shape, (1, 3)) + self.assertEqual(x[b, False].shape, (0, 3)) + self.assertEqual(x[b, True, True, False, True].shape, (0, 3)) + self.assertEqual(x[b, t].shape, (1, 3)) + self.assertEqual(x[b, f].shape, (0, 3)) + + # Using boolean tensor + self.assertEqual(x[m].shape, (1, 3)) + self.assertEqual(x[m, ::2].shape, (1, 2)) + self.assertEqual(x[m, None].shape, (1, 1, 3)) + self.assertEqual(x[m, 0].shape, (1,)) + self.assertEqual(x[m, z].shape, (1,)) + self.assertEqual(x[m, True].shape, (1, 3)) + self.assertEqual(x[m, True, True, True, True].shape, (1, 3)) + self.assertEqual(x[m, False].shape, (0, 3)) + self.assertEqual(x[m, True, True, False, True].shape, (0, 3)) + self.assertEqual(x[m, t].shape, (1, 3)) + self.assertEqual(x[m, f].shape, (0, 3)) + + # Boolean mask in the middle of indices array + x = torch.randn(3, 2, 2, 5, device=device) + self.assertEqual(x[:, m, :].shape, (3, 1, 5)) + self.assertEqual(x[0, m, ::2].shape, (1, 3)) + self.assertEqual(x[..., m, ::2].shape, (3, 1, 3)) + self.assertEqual(x[None, ..., m, ::2].shape, (1, 3, 1, 3)) + + def test_bool_mask_assignment(self, device): + v = torch.tensor([[1, 2], [3, 4]], device=device) + mask = torch.tensor([1, 0], dtype=torch.bool, device=device) + v[mask, :] = 0 + self.assertEqual(v, torch.tensor([[0, 0], [3, 4]], device=device)) + + v = torch.tensor([[1, 2], [3, 4]], device=device) + v[:, mask] = 0 + self.assertEqual(v, torch.tensor([[0, 2], [0, 4]], device=device)) + + def test_multi_dimensional_bool_mask_assignment(self, device): + v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device) + mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool, device=device) + v[:, mask, :] = 0 + self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) + v = torch.tensor([[[[1], [2]], [[3], [4]]]], device=device) + torch.ops.aten.index_put_(v, [None, mask, None], torch.tensor(0)) + self.assertEqual(v, torch.tensor([[[[0], [2]], [[3], [0]]]], device=device)) + def test_byte_mask(self, device): v = torch.randn(5, 7, 3, device=device) mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device) @@ -1243,6 +1307,7 @@ def test_int_indices(self, device): torch.float8_e5m2, torch.float8_e4m3fn, ) + @dtypesIfMPS(torch.float, torch.float16, torch.long, torch.bool) def test_index_put_src_datatype(self, device, dtype): src = torch.ones(3, 2, 4, device=device, dtype=dtype) vals = torch.ones(3, 2, 4, device=device, dtype=dtype) @@ -1989,7 +2054,9 @@ def test_truncate_leading_1s(self, device): self.assertEqual(kernel, kernel2) -instantiate_device_type_tests(TestIndexing, globals(), except_for="meta") +instantiate_device_type_tests( + TestIndexing, globals(), except_for="meta", allow_mps=True +) instantiate_device_type_tests(NumpyTests, globals(), except_for="meta") if __name__ == "__main__": diff --git a/test/test_linalg.py b/test/test_linalg.py index abbf7d6f6e9e..f49db43b4ff2 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -40,7 +40,7 @@ _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \ _group_quantize_tensor_symmetric -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.distributions.binomial import Binomial import torch.backends.opt_einsum as opt_einsum import operator @@ -231,7 +231,7 @@ def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=N @dtypes(torch.float, torch.cfloat) @precisionOverride({torch.float: 1e-06, torch.cfloat: 1e-06}) @tf32_on_and_off(5e-3) - @bf32_on_and_off(5e-3) + @reduced_f32_on_and_off(5e-3) def test_inner(self, device, dtype): def check(a_sizes_, b_sizes_): for a_sizes, b_sizes in ((a_sizes_, b_sizes_), (b_sizes_, a_sizes_)): @@ -785,7 +785,7 @@ def cholesky_test_helper(n, batch_dims, upper): @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @tf32_on_and_off(0.1 if TEST_WITH_ROCM else 0.01) - @bf32_on_and_off(0.01) + @reduced_f32_on_and_off(0.01) def test_old_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -4762,6 +4762,7 @@ def test_matmul_small_brute_force_3d_Nd(self, device, dtype): @onlyCUDA @skipCUDAIfNotRocm # Skipping due to SM89 OOM in CI, UT doesn't do much on NV anyways @dtypes(*floating_types_and(torch.half)) + @precisionOverride({torch.float16: 1e-1}) # TunableOp may occasionally find less precise solution def test_matmul_small_brute_force_tunableop(self, device, dtype): # disable tunableop buffer rotation for all tests everywhere, it can be slow # We set the TunableOp numerical check environment variable here because it is @@ -7199,7 +7200,7 @@ def maybe_transpose(cond, m): *[torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) @@ -7209,7 +7210,7 @@ def test_addmm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_relu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) @@ -7221,7 +7222,7 @@ def test_addmm_relu(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_relu_tunableop_rocm(self, device, dtype): with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) @@ -7235,14 +7236,14 @@ def test_addmm_relu_tunableop_rocm(self, device, dtype): *[torch.bfloat16, torch.half] if TEST_WITH_ROCM or SM53OrLater else [])) @dtypes(*floating_types_and(torch.bfloat16)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addmm_gelu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) @dtypes(torch.float, torch.double) @dtypesIfCUDA(*floating_and_complex_types()) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: for n in [0, 1, 10]: @@ -7840,7 +7841,7 @@ def test_fp16_mv_transposed_first_argument_arm_cpu(self, device, m, k): @dtypes(torch.half, torch.float32, torch.float64, torch.int32, torch.int64, torch.cfloat, torch.cdouble) @dtypesIfCUDA(torch.float32, torch.float64, torch.cfloat, torch.cdouble) @tf32_on_and_off(0.01) - @bf32_on_and_off(0.01) + @reduced_f32_on_and_off(0.01) def test_mm(self, device, dtype): def _test_mm(n, m, p, dtype, genf): # helper function @@ -8020,7 +8021,7 @@ def test_strided_mm_bmm(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_bmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -8133,7 +8134,7 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_addbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -8207,7 +8208,7 @@ def generate_tensor(): @onlyNativeDeviceTypes @dtypes(*floating_and_complex_types_and(torch.bfloat16, torch.half)) @tf32_on_and_off(0.05) - @bf32_on_and_off(0.05) + @reduced_f32_on_and_off(0.05) def test_baddbmm(self, device, dtype): if self.device_type == 'cuda' and dtype is torch.bfloat16 and not SM53OrLater: # cuBLAS does not guarantee BFloat16 support on SM < 53. @@ -9167,7 +9168,7 @@ def dims_full_for_fn(): # ROCm 6.4 passes with tf32=on, but 6.4.1 needed tolerance reduced slightly @tf32_on_and_off(0.002 if torch.version.hip else 0.001) - @bf32_on_and_off(0.001) + @reduced_f32_on_and_off(0.001) def test_broadcast_batched_matmul(self, device): n_dim = random.randint(1, 8) m_dim = random.randint(1, 8) @@ -9504,7 +9505,7 @@ def fn(torchfn, *args): fn(torch.slogdet, (0, 0))) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.07) + @reduced_f32_on_and_off(0.07, 0.005) def test_tensordot(self, device): a = torch.arange(60., device=device).reshape(3, 4, 5) b = torch.arange(24., device=device).reshape(4, 3, 2) diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index db1ffbc38c1f..03c05c7ea6da 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -236,6 +236,32 @@ def test_to_sparse(self, device): _compare_mt_t(sparse_mt, data) _compare_mt_t(mt.grad, data.grad) + def test_to_device(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_device = torch.device("cuda") if device != "cuda" and torch.cuda.is_available() else torch.device("cpu") + mt_device = mt.to(new_device) + + self.assertEqual(mt_device.device.type, new_device.type) + self.assertEqual(mt_device.get_mask().device.type, new_device.type) + self.assertEqual(mt_device.get_data().device.type, new_device.type) + + def test_to_dtype(self, device): + for sample in _generate_sample_data(device=device): + data = sample.input + mask = sample.kwargs["mask"] + mt = masked_tensor(data, mask, requires_grad=True) + + new_dtype = torch.float64 if data.dtype == torch.float32 else torch.float32 + mt_dtype = mt.to(new_dtype) + + self.assertEqual(mt_dtype.dtype, new_dtype) + self.assertEqual(mt_dtype.get_mask().dtype, torch.bool) + self.assertEqual(mt_dtype.get_data().dtype, new_dtype) + def test_to_dense(self, device): samples = _generate_sample_data( device=device, diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 33127c689e20..943c0ae1f550 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -496,7 +496,8 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) @parametrize("a_row_major", [False, True]) @parametrize("b_row_major", [False, True]) - def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): + @parametrize("max_autotune", [False, True]) + def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune): torch._dynamo.reset() device = "cuda" @@ -506,12 +507,18 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major): align = 16 // dtype_AB.itemsize f_ref = torch._grouped_mm + + options = {} + if max_autotune: + options.update( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + } + ) f = torch.compile( f_ref, - options={ - "max_autotune": True, - "max_autotune_gemm_backends": "TRITON", - }, + options=options, ) if op == "2d/2d": @@ -1417,8 +1424,16 @@ def test_honor_sm_carveout(self) -> None: ] self.assertEqual(no_carveout, no_carveout_again) - self.assertNotEqual(no_carveout, carveout_66) - self.assertNotEqual(carveout_66, carveout_0) + capability = torch.cuda.get_device_capability() + if capability == (10, 0): + # expected failure + # CUTLASS only supports SM carveout via green contexts on SM100 + self.assertEqual(no_carveout, carveout_66) + self.assertEqual(carveout_66, carveout_0) + else: + # correct behavior + self.assertNotEqual(no_carveout, carveout_66) + self.assertNotEqual(carveout_66, carveout_0) def test_pack_uint4(self): """ diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 23788653cc6c..e2ec92fc8dad 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -27,7 +27,7 @@ instantiate_device_type_tests, dtypes, ) -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off # batched grad doesn't support mkldnn gradcheck = functools.partial(gradcheck, check_batched_grad=False) @@ -284,15 +284,15 @@ def _test_conv_base(self, dim): if bias: self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv1d(self): self._test_conv_base(dim=1) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv2d(self): self._test_conv_base(dim=2) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv3d(self): self._test_conv_base(dim=3) @@ -407,7 +407,7 @@ def _test_conv_deconv_nhwc_base(self, conv_module, weight_memory_format, dtype, self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) @@ -443,7 +443,7 @@ def test_conv_nhwc_lower_precision(self, dtype): self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) @@ -532,15 +532,15 @@ def _test_conv_transpose_base(self, dim): if bias: self.assertEqual(conv.bias.grad, conv_ref.bias.grad) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose1d(self): self._test_conv_transpose_base(dim=1) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose2d(self): self._test_conv_transpose_base(dim=2) - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) @@ -1680,21 +1680,29 @@ def test_mlkdnn_get_set(self): # get/set mkldnn ops with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="tf32"): + self.assertEqual(torch.backends.mkldnn.fp32_precision, "tf32") with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): self.assertEqual(torch.backends.mkldnn.fp32_precision, "none") # get/set matmul torch.backends.mkldnn.matmul.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + torch.backends.mkldnn.matmul.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") torch.backends.mkldnn.matmul.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") # get/set conv torch.backends.mkldnn.conv.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "bf16") + torch.backends.mkldnn.conv.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "tf32") torch.backends.mkldnn.conv.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.conv.fp32_precision, "none") # get/set rnn torch.backends.mkldnn.rnn.fp32_precision = "bf16" self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "bf16") + torch.backends.mkldnn.rnn.fp32_precision = "tf32" + self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "tf32") torch.backends.mkldnn.rnn.fp32_precision = "none" self.assertEqual(torch.backends.mkldnn.rnn.fp32_precision, "none") @@ -1710,18 +1718,14 @@ def test_default_use_parent(self): torch.backends.mkldnn.matmul.fp32_precision = "none" with torch.backends.mkldnn.flags(enabled=None, fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") + with torch.backends.mkldnn.flags(enabled=None, fp32_precision="tf32"): + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") with torch.backends.mkldnn.flags(enabled=None, fp32_precision="none"): with torch.backends.flags(fp32_precision="bf16"): self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "bf16") with torch.backends.flags(fp32_precision="tf32"): - # when parent is a not supported precision, use default - self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") + self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "tf32") - @recover_orig_fp32_precision - def test_invalid(self): - # use default if user set a not supported precision - torch.backends.mkldnn.matmul.fp32_precision = "tf32" - self.assertEqual(torch.backends.mkldnn.matmul.fp32_precision, "none") instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/test/test_mps.py b/test/test_mps.py index 89e68b171852..ea1013c97213 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -8013,6 +8013,18 @@ def test_64bit_index_select(self): gc.collect() torch.mps.empty_cache() + @serialTest() + def test_rand_2b_raises(self): + if MACOS_VERSION < 14.0: + raise unittest.SkipTest("Crashes on MacOS-13") + int32_max = torch.iinfo(torch.int32).max + with self.assertRaises(RuntimeError): + # This used to crash with NDArray dimension length > INT_MAX + x = torch.randint(0, 10, (int32_max + 1,), dtype=torch.int8, device='mps') + x = torch.randint(0, 10, (int32_max,), dtype=torch.int8, device='mps') + self.assertEqual(x.numel(), int32_max) + del x + class TestLogical(TestCaseMPS): def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False): @@ -9257,6 +9269,18 @@ def test_sdpa_mask_fp16_L6(self): def test_sdpa_mask_fp16_L6_S17_NH23_HS121(self): self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) + # Regression test from: https://github.com/pytorch/pytorch/issues/156707 + @parametrize("dtype", [torch.float16, torch.float32]) + def test_sdpa_full_mask(self, dtype): + q = torch.randn(1, 1, 2, 4, dtype=dtype) + k = torch.randn(1, 1, 2, 4, dtype=dtype) + v = torch.randn(1, 1, 2, 4, dtype=dtype) + mask = torch.tensor([[[[False, False], [True, True]]]], dtype=torch.bool) + + out_cpu = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out_mps = F.scaled_dot_product_attention(q.to('mps'), k.to('mps'), v.to('mps'), attn_mask=mask.to('mps')) + self._compare_tensors(out_mps.cpu(), out_cpu) + @parametrize("dtype", [torch.float16, torch.float32]) def test_sdpa_3d_input(self, dtype): head_num, seq_len, embed_dim = 16, 16, 80 diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 6dbb0b2cdad8..0e0234b08941 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -6751,10 +6751,11 @@ def check_forward_backward(skip_backward=False): and check_cudnn and (dtype == torch.float16 or dtype == torch.bfloat16) ): - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.CUDNN_ATTENTION - ): - check_forward_backward() + with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"): + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.CUDNN_ATTENTION + ): + check_forward_backward() @skipIfTorchDynamo("SDPA test compiles internally") @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") diff --git a/test/test_nn.py b/test/test_nn.py index 0323080728b3..218a65f388f0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -55,7 +55,7 @@ from torch.testing._internal.common_utils import dtype2prec_DONTUSE from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_off, tf32_on from torch.types import _TensorOrTensors -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() @@ -8278,7 +8278,7 @@ def _test_module_empty_inputs(self, module, inputs): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off() - @bf32_on_and_off() + @reduced_f32_on_and_off() def test_affine_2d_rotate0(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8319,7 +8319,7 @@ def test_affine_2d_rotate0(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.01 if TEST_WITH_ROCM else 0.001) - @bf32_on_and_off(0.001) + @reduced_f32_on_and_off(0.001) def test_affine_2d_rotate90(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8369,7 +8369,7 @@ def test_affine_2d_rotate90(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_2d_rotate45(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8447,7 +8447,7 @@ def test_avg_pool_large_tensor2(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_2d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. @@ -8500,7 +8500,7 @@ def test_affine_2d_rotateRandom(self, device): "Scipy v1.0 and/or numpy not found") @expectedFailureMPS # aten::grid_sampler_3d not implemented https://github.com/pytorch/pytorch/issues/77764 @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.005) + @reduced_f32_on_and_off(0.005) def test_affine_3d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate # scipy.ndimage.affine_transform, so we need to skip. diff --git a/test/test_openreg.py b/test/test_openreg.py index 1fab8c4261c7..dc52231ff7bf 100644 --- a/test/test_openreg.py +++ b/test/test_openreg.py @@ -10,7 +10,7 @@ import numpy as np import psutil -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch from torch.serialization import safe_globals @@ -285,7 +285,6 @@ def test_manual_seed(self): self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] # Autograd - @unittest.skipIf(not IS_LINUX, "Only works on linux") def test_autograd_init(self): # Make sure autograd is initialized torch.ones(2, requires_grad=True, device="openreg").sum().backward() @@ -584,4 +583,5 @@ def test_open_device_dlpack(self): if __name__ == "__main__": - run_tests() + if IS_LINUX: + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 26f8865b3a00..201b0323a86f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1109,7 +1109,22 @@ def _case_four_transform(t): if op.is_factory_function and sample.kwargs.get("dtype", None) is None: op_out(out=out) else: - with self.assertRaises(RuntimeError, msg=msg_fail): + # TODO: Remove me when all ops will raise type error on mismatched types + exc_type = ( + TypeError + if op.name + in [ + "_chunk_cat", + "cat", + "column_stack", + "dstack", + "hstack", + "vstack", + "stack", + ] + else RuntimeError + ) + with self.assertRaises(exc_type, msg=msg_fail): op_out(out=out) @ops( diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index aef4cb0e6917..e0480ba6a684 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2480,6 +2480,39 @@ def __torch_dispatch__(self, func, types, args, kwargs=None): self.assertEqual(res, t.a) self.assertIs(type(res), torch.Tensor) + def test_custom_dispatch_mode_supports_higher_order_operators(self): + class Mode(TorchDispatchMode): + supports_higher_order_operators = True + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + if func is torch.ops.higher_order.cond: + return torch.ones(3, 3) + return NotImplemented + + pred = torch.tensor(True) + x = torch.randn(1, 1) + with Mode(): + out = torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + self.assertEqual(out, torch.ones(3, 3)) + + def test_custom_dispatch_mode_not_supports_higher_order_operators(self): + class Mode(TorchDispatchMode): + supports_higher_order_operators = False + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + if func is torch.ops.higher_order.cond: + return torch.ones(3, 3) + return NotImplemented + + pred = torch.tensor(True) + x = torch.randn(1, 1) + with self.assertRaisesRegex( + NotImplementedError, + "There was no rule registered for HigherOrderOperator cond and mode", + ): + with Mode(): + torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,)) + class TestPythonDispatcher(TestCase): def test_basic(self): diff --git a/test/test_rename_privateuse1_to_existing_device.py b/test/test_rename_privateuse1_to_existing_device.py new file mode 100644 index 000000000000..539412a32238 --- /dev/null +++ b/test/test_rename_privateuse1_to_existing_device.py @@ -0,0 +1,59 @@ +# Owner(s): ["module: PrivateUse1"] + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + + +class DummyPrivateUse1Module: + @staticmethod + def is_available(): + return True + + @staticmethod + def is_autocast_enabled(): + return True + + @staticmethod + def get_autocast_dtype(): + return torch.float16 + + @staticmethod + def set_autocast_enabled(enable): + pass + + @staticmethod + def set_autocast_dtype(dtype): + pass + + @staticmethod + def get_amp_supported_dtype(): + return [torch.float16] + + +class TestRenamePrivateuseoneToExistingBackend(TestCase): + def test_external_module_register_with_existing_backend(self): + torch.utils.rename_privateuse1_backend("maia") + with self.assertRaisesRegex(RuntimeError, "has already been set"): + torch.utils.rename_privateuse1_backend("dummmy") + + custom_backend_name = torch._C._get_privateuse1_backend_name() + self.assertEqual(custom_backend_name, "maia") + + with self.assertRaises(AttributeError): + torch.maia.is_available() + + with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"): + with torch.autocast(device_type=custom_backend_name): + pass + torch._register_device_module("maia", DummyPrivateUse1Module) + + torch.maia.is_available() # type: ignore[attr-defined] + with torch.autocast(device_type=custom_backend_name): + pass + + self.assertEqual(torch._utils._get_device_index("maia:1"), 1) + self.assertEqual(torch._utils._get_device_index(torch.device("maia:2")), 2) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index ba830489a99b..7af57f23b8fe 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -59,7 +59,7 @@ from torch.testing._internal.common_cuda import ( tf32_on_and_off, TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers) -from torch.testing._internal.common_mkldnn import bf32_on_and_off +from torch.testing._internal.common_mkldnn import reduced_f32_on_and_off from torch.testing._internal.common_dtype import ( floating_types_and, get_all_math_dtypes, all_types_and_complex_and, complex_types, all_types_and, floating_types, floating_and_complex_types, integral_types_and, @@ -2557,7 +2557,7 @@ def test_cdist_cuda_backward(self, device): self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001) @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) - @bf32_on_and_off(0.08) + @reduced_f32_on_and_off(0.08) def test_cdist_large(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(1000, 10, device=device) @@ -2568,7 +2568,7 @@ def test_cdist_large(self, device): @slowTest @tf32_on_and_off(0.01) - @bf32_on_and_off(0.08) + @reduced_f32_on_and_off(0.08) def test_cdist_large_batch(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 1000, 10, device=device) @@ -2578,7 +2578,7 @@ def test_cdist_large_batch(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) - @bf32_on_and_off(0.04) + @reduced_f32_on_and_off(0.04) def test_cdist_non_contiguous(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(5, 7, device=device).mT @@ -2606,7 +2606,7 @@ def test_cdist_non_contiguous(self, device): self.assertEqual(expected, actual) @tf32_on_and_off(0.005) - @bf32_on_and_off(0.04) + @reduced_f32_on_and_off(0.04) def test_cdist_non_contiguous_batch(self, device): for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: x = torch.randn(4, 3, 2, 5, 7, device=device).mT diff --git a/test/test_transformers.py b/test/test_transformers.py index 85c9f4a07cec..89db8d798c26 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -49,6 +49,7 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, + SM90OrLater, tf32_on_and_off, tf32_enabled, ) @@ -1617,6 +1618,34 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) + @onlyCUDA + @unittest.skipIf( + not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION + or not PLATFORM_SUPPORTS_CUDNN_ATTENTION, + "Efficient or cuDNN Attention was not built for this system", + ) + @parametrize("kernel", [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION]) + def test_mask_invalid_last_dim_stride(self, device, kernel): + with sdpa_kernel(backends=[kernel]): + dtype = torch.float16 + make_tensor = partial(torch.rand, device=device, dtype=dtype) + size = SdpaShape(2, 2, 8, 8) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + attn_mask = make_tensor((2, 2, 8, 8)) + # Passing in a attn_mask with last dim stride not equal to 1 will error + attn_mask.as_strided_(size, [2, 2, 2, 2]) + + with self.assertWarnsRegex( + UserWarning, + "GPU backends require attn_mask's last dimension to have stride 1 while the CPU does not", + ): + self.assertRaises( + RuntimeError, + lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask, 0.0, False + ), + ) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware") @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION]) @@ -2628,7 +2657,6 @@ def test_cudnn_attention_gqa(self, device): @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") - @unittest.expectedFailure # cuDNN currently doesn't support this on SM100+/fails graph validation def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) @@ -2639,7 +2667,7 @@ def test_cudnn_attention_d256_heuristic(self, device): v_shape = SdpaShape(batch, num_heads, seq_len, head_dim_v) query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape) - with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION], set_priority=True): + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH], set_priority=True): actual = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) actual.backward(torch.randn_like(actual)) @@ -2677,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device): @skipIfRocm # No cuDNN Attention - @unittest.skipIf(True, "broken as of cuDNN 9.10") + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_fail_d128(self, device): # Test that cuDNN attention dispatching correctly bails out on d > 128 b, h = 1, 2 @@ -2692,6 +2720,7 @@ def test_cudnn_attention_fail_d128(self, device): ISSM90 = device_cap == (9, 0) ISSM100 = device_cap == (10, 0) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + # SM90/100 support d <= 256 as of cuDNN 9.5.1+ if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501: torch.nn.functional.scaled_dot_product_attention(q, k, v) else: @@ -3127,19 +3156,15 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - device_capability = None - if "cuda" in str(device): - device_capability = torch.cuda.get_device_capability() - prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" in os.environ - prefer_cudnn = prefer_cudnn and device_capability and (device_capability == (9, 0) or device_capability == (10, 0)) - # TODO we are currently disabling this by default, lets assert that this returns # FlashAttention, we need to change when we make remove opt-in for cudnn - if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and prefer_cudnn: - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) - elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and not prefer_cudnn: # e.g., we're on Windows + elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) diff --git a/test/test_transformers_privateuse1.py b/test/test_transformers_privateuse1.py index 728b0a118825..0aa15260d094 100644 --- a/test/test_transformers_privateuse1.py +++ b/test/test_transformers_privateuse1.py @@ -4,7 +4,7 @@ from collections import namedtuple from functools import partial -import pytorch_openreg # noqa: F401 +import torch_openreg # noqa: F401 import torch from torch.nn.attention import SDPBackend diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index 88fcdd3a5dca..59d856ec4fc9 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -1046,7 +1046,7 @@ def test_cat_out_different_dtypes(self, device): and not (out_dtype.is_floating_point or out_dtype.is_complex)) or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)): # This combinations do not support type conversion to a different class out type - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): torch.cat([x, y], out=out) else: torch.cat([x, y], out=out) diff --git a/test/torch_np/numpy_tests/linalg/test_linalg.py b/test/torch_np/numpy_tests/linalg/test_linalg.py index afda92e5b6b9..869c4af75391 100644 --- a/test/torch_np/numpy_tests/linalg/test_linalg.py +++ b/test/torch_np/numpy_tests/linalg/test_linalg.py @@ -489,7 +489,7 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): # kept apart from TestSolve for use for testing with matrices. def do(self, a, b, tags): x = linalg.solve(a, b) - assert_almost_equal(b, dot_generalized(a, x)) + assert_almost_equal(b, dot_generalized(a, x), single_decimal=5) assert_(consistent_subclass(x, b)) diff --git a/test/torch_np/test_indexing.py b/test/torch_np/test_indexing.py index eac68246bd5a..084cc2f73e8b 100644 --- a/test/torch_np/test_indexing.py +++ b/test/torch_np/test_indexing.py @@ -412,6 +412,77 @@ def test_special_index_types(self): self._test_cases(cases + numpy_torch_cases, "Special index types") + def test_ellipsis(self): + """Tests containing ellipsis.""" + cases = [ + # Ellipsis + Basic indexing + { + "shape": (3, 4, 5), + "index": (slice(None), 0, ..., slice(None)), + "name": "empty ellipsis without advanced indexing", + }, + { + "shape": (3, 4, 5), + "index": (slice(None), ..., 0), + "name": "non-empty ellipsis without advanced indexing", + }, + # Ellipsis + Advanced indexing without separation + { + "shape": (3, 4, 5), + "index": (slice(None), ..., slice(None), (0, 1)), + "name": "empty ellipsis without separation", + }, + { + "shape": (3, 4, 5), + "index": (slice(None), ..., (0, 1)), + "name": "non-empty ellipsis without separation", + }, + # Ellipsis + Advanced indexing with separation + { + "shape": (3, 4, 5), + "index": (slice(None), (0, 1), ..., (0, 1)), + "name": "empty ellipsis separation", + }, + { + "shape": (1, 3, 4, 5), + "index": (slice(None), (0, 1), ..., (0, 1)), + "name": "non-empty ellipsis separation", + }, + { + "shape": (4, 3, 5), + "index": (slice(None), ((0,), (1,)), ..., (0, 1)), + "name": "empty ellipsis separation with 2-depth int sequence", + }, + { + "shape": (4, 3, 5, 6), + "index": (slice(None), ((0,), (1,)), ..., (0, 1), slice(None)), + "name": "empty ellipsis separation with 2-depth int sequence and end slice", + }, + { + "shape": (4, 3, 5, 6), + "index": (slice(None), ((0,), (1,)), ..., (0, 1), (((0, 1), (1, 2)),)), + "name": "empty ellipsis separation with 2 and 3-depth int sequence", + }, + # Ellipsis + Boolean masks in advanced indexing with separation + { + "shape": (3, 4, 5), + "index": (slice(None), True, True, True, ..., 0, 0), + "name": "empty ellipsis separation with 0-dim boolean masks", + }, + { + "shape": (4, 3, 5), + "index": (slice(None), (True, True, False), ..., (0, 1)), + "name": "empty ellipsis separation with 1-dim boolean masks", + }, + # TODO(manuelcandales) Fix issue #71673 and enable this case + # { + # "shape": (1, 2, 2, 4, 5), + # "index": (slice(None), ((True, False), (True, True)), (0, 1, 2), ..., (0,)), + # "name": "empty ellipsis separation with 2-dim boolean masks", + # }, + ] + self._test_cases(cases, "Ellipsis and advanced indexing separation") + if __name__ == "__main__": run_tests() diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7af47591bd08..f0349c2484b6 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1267,6 +1267,11 @@ mean: not_implemented("native_layer_norm_backward mean") rstd: not_implemented("native_layer_norm_backward rstd") +- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor) + input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple())" + result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape) + result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape) + - name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())" result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group) @@ -2896,10 +2901,6 @@ output_differentiability: [True, False, False, False, False, False] query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale) -- name: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) - output_differentiability: [True, False, False, False, False, False, False, False, False] - query, key, value: _cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) - - name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) output_differentiability: [True, False, False, False, False, False, False, False, False] query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale) diff --git a/tools/bazel.bzl b/tools/bazel.bzl index cd263ba4d324..9b662859adb4 100644 --- a/tools/bazel.bzl +++ b/tools/bazel.bzl @@ -2,7 +2,7 @@ load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda_enabled") load("@rules_python//python:defs.bzl", "py_binary", "py_library") load("@pip_deps//:requirements.bzl", "requirement") -load("@pytorch//c10/macros:cmake_configure_file.bzl", "cmake_configure_file") +load("@pytorch//torch/headeronly/macros:cmake_configure_file.bzl", "cmake_configure_file") load("@pytorch//tools/config:defs.bzl", "if_cuda") def _genrule(**kwds): diff --git a/tools/build/bazel/requirements.in b/tools/build/bazel/requirements.in index 37750163da81..7498f9065f0c 100644 --- a/tools/build/bazel/requirements.in +++ b/tools/build/bazel/requirements.in @@ -1,6 +1,6 @@ PyYAML==6.0.1 numpy==1.26.4 -requests==2.32.2 +requests==2.32.4 setuptools==78.1.1 sympy==1.12 typing_extensions==4.11.0 diff --git a/tools/build/bazel/requirements.txt b/tools/build/bazel/requirements.txt index a15924660167..dab9792ceae3 100644 --- a/tools/build/bazel/requirements.txt +++ b/tools/build/bazel/requirements.txt @@ -203,9 +203,9 @@ pyyaml==6.0.1 \ --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f # via -r requirements.in -requests==2.32.2 \ - --hash=sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289 \ - --hash=sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c +requests==2.32.4 \ + --hash=sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c \ + --hash=sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422 # via -r requirements.in sympy==1.12 \ --hash=sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5 \ diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index fbf9808e9b26..137e4637bdb4 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -13,17 +13,20 @@ import time -def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: +def run_command( + args: list[str], + env: dict[str, str] | None = None, +) -> subprocess.CompletedProcess[str]: logging.debug("$ %s", " ".join(args)) start_time = time.monotonic() try: - return subprocess.run(args, check=True) + return subprocess.run(args, env=env, text=True, encoding="utf-8", check=True) finally: end_time = time.monotonic() logging.debug("took %dms", (end_time - start_time) * 1000) -if __name__ == "__main__": +def main() -> None: parser = argparse.ArgumentParser(description="pip initializer") parser.add_argument( "packages", @@ -52,17 +55,16 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: stream=sys.stderr, ) - uv_available = ( - any(prefix in sys.base_prefix for prefix in ["uv/python", "uv\\python"]) - and shutil.which("uv") is not None - ) - - if uv_available: - pip_args = ["uv", "pip", "install"] - elif sys.executable: - pip_args = [sys.executable, "-mpip", "install"] - else: - pip_args = ["pip3", "install"] + env: dict[str, str] = { + **os.environ, + "UV_PYTHON": sys.executable, + "UV_PYTHON_DOWNLOADS": "never", + "FORCE_COLOR": "1", + "CLICOLOR_FORCE": "1", + } + uv_index = env.get("UV_INDEX", env.get("PIP_EXTRA_INDEX_URL")) + if uv_index: + env["UV_INDEX"] = uv_index # If we are in a global install, use `--user` to install so that you do not # need root access in order to initialize linters. @@ -70,9 +72,20 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: # However, `pip install --user` interacts poorly with virtualenvs (see: # https://bit.ly/3vD4kvl) and conda (see: https://bit.ly/3KG7ZfU). So in # these cases perform a regular installation. - in_conda = os.environ.get("CONDA_PREFIX") is not None - in_virtualenv = os.environ.get("VIRTUAL_ENV") is not None - if not in_conda and not in_virtualenv: + in_conda = env.get("CONDA_PREFIX") is not None + in_virtualenv = env.get("VIRTUAL_ENV") is not None + need_user_flag = not in_conda and not in_virtualenv + + uv: str | None = shutil.which("uv") + is_uv_managed_python = "uv/python" in sys.base_prefix.replace("\\", "/") + if uv and (is_uv_managed_python or not need_user_flag): + pip_args = [uv, "pip", "install"] + elif sys.executable: + pip_args = [sys.executable, "-mpip", "install"] + else: + pip_args = ["pip3", "install"] + + if need_user_flag: pip_args.append("--user") pip_args.extend(args.packages) @@ -92,4 +105,8 @@ def run_command(args: list[str]) -> subprocess.CompletedProcess[bytes]: print(f"Would have run: {pip_args}") sys.exit(0) - run_command(pip_args) + run_command(pip_args, env=env) + + +if __name__ == "__main__": + main() diff --git a/tools/linter/adapters/pyproject_linter.py b/tools/linter/adapters/pyproject_linter.py index e3a61cc63423..5e046509319f 100644 --- a/tools/linter/adapters/pyproject_linter.py +++ b/tools/linter/adapters/pyproject_linter.py @@ -128,16 +128,16 @@ def check_file(filename: str) -> list[LintMessage]: ), ) ] - if f"{python_major}.{large_minor}" in supported_python_versions: - return [ - format_error_message( - filename, - message=( - "'project.requires-python' must specify a maximum version, " - f"but found {requires_python!r}." - ), - ) - ] + # if f"{python_major}.{large_minor}" in supported_python_versions: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.requires-python' must specify a maximum version, " + # f"but found {requires_python!r}." + # ), + # ) + # ] classifiers = project.get("classifiers") if not ( @@ -158,49 +158,49 @@ def check_file(filename: str) -> list[LintMessage]: ) ] - python_version_classifiers = [ - c - for c in classifiers - if ( - c.startswith("Programming Language :: Python :: ") - and not c.endswith((f":: {python_major}", f":: {python_major} :: Only")) - ) - ] - if python_version_classifiers: - python_version_classifier_set = set(python_version_classifiers) - supported_python_version_classifier_set = { - f"Programming Language :: Python :: {v}" - for v in supported_python_versions - } - if python_version_classifier_set != supported_python_version_classifier_set: - missing_classifiers = sorted( - supported_python_version_classifier_set - - python_version_classifier_set - ) - extra_classifiers = sorted( - python_version_classifier_set - - supported_python_version_classifier_set - ) - if missing_classifiers: - return [ - format_error_message( - filename, - message=( - "'project.classifiers' is missing the following classifier(s):\n" - + "\n".join(f" {c!r}" for c in missing_classifiers) - ), - ) - ] - if extra_classifiers: - return [ - format_error_message( - filename, - message=( - "'project.classifiers' contains extra classifier(s):\n" - + "\n".join(f" {c!r}" for c in extra_classifiers) - ), - ) - ] + # python_version_classifiers = [ + # c + # for c in classifiers + # if ( + # c.startswith("Programming Language :: Python :: ") + # and not c.endswith((f":: {python_major}", f":: {python_major} :: Only")) + # ) + # ] + # if python_version_classifiers: + # python_version_classifier_set = set(python_version_classifiers) + # supported_python_version_classifier_set = { + # f"Programming Language :: Python :: {v}" + # for v in supported_python_versions + # } + # if python_version_classifier_set != supported_python_version_classifier_set: + # missing_classifiers = sorted( + # supported_python_version_classifier_set + # - python_version_classifier_set + # ) + # extra_classifiers = sorted( + # python_version_classifier_set + # - supported_python_version_classifier_set + # ) + # if missing_classifiers: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.classifiers' is missing the following classifier(s):\n" + # + "\n".join(f" {c!r}" for c in missing_classifiers) + # ), + # ) + # ] + # if extra_classifiers: + # return [ + # format_error_message( + # filename, + # message=( + # "'project.classifiers' contains extra classifier(s):\n" + # + "\n".join(f" {c!r}" for c in extra_classifiers) + # ), + # ) + # ] return [] diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 49ae353c7d02..706881a8f10f 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -1,23 +1,35 @@ +aLoad +aLoads ans +aStore +aStores belows +bLoad +bLoads +bStore +bStores BU contiguities contiguity coo -Din -Dout -dOut +deser +din +dout ElementE followings fro froms Halfs hsa +indexT +inH inp inps inpt inpts matA +matB +matC nd nin NotIn @@ -30,6 +42,7 @@ ot overrideable oW padD +posIn ptd rebuild rebuilt @@ -45,4 +58,6 @@ strat supercede supercedes te +THW +tne WONT diff --git a/tools/nightly.py b/tools/nightly.py index 8409173e8b5b..0ed8cfe165aa 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -250,6 +250,7 @@ def __init__( self._env = { "PIP_EXTRA_INDEX_URL": self.pip_source.index_url, "UV_INDEX": self.pip_source.index_url, + "UV_PYTHON_DOWNLOADS": "never", "FORCE_COLOR": "1", "CLICOLOR_FORCE": "1", } @@ -475,13 +476,12 @@ def uv( cmd = [str(self.bindir / "uv"), *args] env = popen_kwargs.pop("env", None) or {} check = popen_kwargs.pop("check", True) - env["UV_PYTHON"] = str(python) return subprocess.run( cmd, check=check, text=True, encoding="utf-8", - env={**self._env, **env}, + env={**self._env, **env, "UV_PYTHON": str(python)}, **popen_kwargs, ) diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index a5affc2510b7..38d1f94b178b 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -78,6 +78,9 @@ class GpuData: uuid: str utilization: float mem_utilization: float + allocated_mem: float + allocated_mem_value: float + total_mem_value: float try: @@ -259,6 +262,7 @@ def _generate_stats(self, data_list: list[float]) -> UtilizationStats: return UtilizationStats( avg=round(avg, 2), max=round(maxi, 2), + raw=data_list, ) def _output_data(self) -> None: @@ -338,20 +342,33 @@ def _calculate_gpu_utilization(self, data_list: list[UsageData]) -> list[GpuUsag calculate_gpu = [] gpu_mem_utilization = defaultdict(list) gpu_utilization = defaultdict(list) + gpu_allocated_mem = defaultdict(list) + gpu_allocated_mem_values = defaultdict(list) + gpu_total_mem_values = defaultdict(float) for data in data_list: for gpu in data.gpu_list: gpu_mem_utilization[gpu.uuid].append(gpu.mem_utilization) gpu_utilization[gpu.uuid].append(gpu.utilization) + gpu_allocated_mem[gpu.uuid].append(gpu.allocated_mem) + gpu_allocated_mem_values[gpu.uuid].append(gpu.allocated_mem_value) + gpu_total_mem_values[gpu.uuid] = gpu.total_mem_value for gpu_uuid in gpu_utilization.keys(): gpu_util_stats = self._generate_stats(gpu_utilization[gpu_uuid]) gpu_mem_util_stats = self._generate_stats(gpu_mem_utilization[gpu_uuid]) + gpu_allocated_mem_stats = self._generate_stats(gpu_allocated_mem[gpu_uuid]) + gpu_allocated_mem_value_stats = self._generate_stats( + gpu_allocated_mem_values[gpu_uuid] + ) calculate_gpu.append( GpuUsage( uuid=gpu_uuid, util_percent=gpu_util_stats, mem_util_percent=gpu_mem_util_stats, + allocated_mem_percent=gpu_allocated_mem_stats, + allocated_mem_value=gpu_allocated_mem_value_stats, + total_mem_value=gpu_total_mem_values[gpu_uuid], ) ) return calculate_gpu @@ -382,11 +399,21 @@ def _collect_gpu_data(self) -> list[GpuData]: # see https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html gpu_utilization = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle) gpu_uuid = pynvml.nvmlDeviceGetUUID(gpu_handle) + gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(gpu_handle) + mem_utilization = gpu_utilization.memory + + allocate_mem_MB = gpu_memory_info.used / 1024**2 + total_mem_MB = gpu_memory_info.total / 1024**2 + allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 + gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization.gpu, - mem_utilization=gpu_utilization.memory, + mem_utilization=mem_utilization, + allocated_mem=allocate_mem_percent, + allocated_mem_value=allocate_mem_MB, + total_mem_value=total_mem_MB, ) ) elif self._has_amdsmi: @@ -397,11 +424,20 @@ def _collect_gpu_data(self) -> list[GpuData]: gpu_uuid = amdsmi.amdsmi_get_gpu_device_uuid(handle) gpu_utilization = engine_usage["gfx_activity"] gpu_mem_utilization = gpu_utilization["umc_activity"] + mem_info = amdsmi.amdsmi_get_gpu_memory_usage(handle) + + allocate_mem_MB = mem_info["vram_usage"] / 1024**2 + total_mem_MB = mem_info["vram_total"] / 1024**2 + allocate_mem_percent = allocate_mem_MB / total_mem_MB * 100 + gpu_data_list.append( GpuData( uuid=gpu_uuid, utilization=gpu_utilization, mem_utilization=gpu_mem_utilization, + allocated_mem=allocate_mem_percent, + allocated_mem_value=allocate_mem_MB, + total_mem_value=total_mem_MB, ) ) return gpu_data_list @@ -499,7 +535,9 @@ def get_processes_running_python_tests() -> list[Any]: cmd = " ".join(process.cmdline()) processName = process.name() pid = process.pid - if "python" in processName and cmd.startswith("python"): + is_python = "python" in processName and "python" in cmd + is_pytest = "pytest" in cmd + if is_python or is_pytest: python_test_processes.append({"pid": pid, "cmd": cmd}) except Exception: pass diff --git a/tools/stats/utilization_stats_lib.py b/tools/stats/utilization_stats_lib.py index 740fe71f1768..33551fd55de5 100644 --- a/tools/stats/utilization_stats_lib.py +++ b/tools/stats/utilization_stats_lib.py @@ -5,7 +5,7 @@ from dataclasses_json import DataClassJsonMixin -_DATA_MODEL_VERSION = 1.0 +_DATA_MODEL_VERSION = 1.5 # data model for test log usage @@ -13,6 +13,7 @@ class UtilizationStats: avg: Optional[float] = None max: Optional[float] = None + raw: Optional[list[float]] = None @dataclass @@ -36,6 +37,9 @@ class GpuUsage(DataClassJsonMixin): uuid: Optional[str] = None util_percent: Optional[UtilizationStats] = None mem_util_percent: Optional[UtilizationStats] = None + allocated_mem_percent: Optional[UtilizationStats] = None + allocated_mem_value: Optional[UtilizationStats] = None + total_mem_value: Optional[float] = None @dataclass diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index bc92f97b3956..8d761068d1e6 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -29,7 +29,6 @@ endif() set(LIBSHM_SRCDIR ${TORCH_SRC_DIR}/lib/${LIBSHM_SUBDIR}) add_subdirectory(${LIBSHM_SRCDIR}) - # Generate files set(TOOLS_PATH "${TORCH_ROOT}/tools") diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 69d90d4e7a1f..dea17d26ef21 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -40,6 +40,7 @@ from torch._C import ( ) from torch._prims_common import DeviceLikeType from torch.autograd.graph import Node as _Node +from torch.cuda import _POOL_HANDLE from torch.fx.node import Node as FxNode from torch.package import PackageExporter from torch.storage import TypedStorage, UntypedStorage @@ -1300,9 +1301,20 @@ def _initCrashHandler() -> None: ... # NB: There is no Capsule type in typing, see # https://github.com/python/cpython/issues/109562 -def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack -def _to_dlpack_versioned(data: Tensor) -> Any: ... # THPModule_toDLPackVersioned +def _to_dlpack( + data: Tensor, + dl_device: tuple[IntEnum, _int] | None = None, + copy: _bool | None = None, +) -> Any: ... # THPModule_toDLPack +def _to_dlpack_versioned( + data: Tensor, + dl_device: tuple[IntEnum, _int] | None = None, + copy: _bool | None = None, +) -> Any: ... # THPModule_toDLPackVersioned def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack +def _torchDeviceToDLDevice( + device: torch.device, +) -> tuple[_int, _int]: ... # THPModule_torchDeviceToDLDevice def _get_cpp_backtrace( frames_to_skip: _int, maximum_number_of_frames: _int, @@ -2289,7 +2301,7 @@ class _CUDAGraph: def __new__(cls, keep_graph: _bool = ...) -> Self: ... def capture_begin( self, - pool: tuple[_int, _int] | None = ..., + pool: _POOL_HANDLE | None = ..., capture_error_mode: str = "global", ) -> None: ... def capture_end(self) -> None: ... @@ -2297,7 +2309,7 @@ class _CUDAGraph: def register_generator_state(self, Generator) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... - def pool(self) -> tuple[_int, _int]: ... + def pool(self) -> _POOL_HANDLE: ... def enable_debug_mode(self) -> None: ... def debug_dump(self, debug_path: str) -> None: ... def raw_cuda_graph(self) -> _int: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 2efe44c86b55..20805d56e370 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -350,6 +350,21 @@ class ProcessGroup: ) -> None: ... def rank(self) -> int: ... def size(self) -> int: ... + def split_group( + self, + new_ranks: list[int], + timeout: Optional[timedelta] = None, + pg_options: Optional[Backend.Options] = None, + group_desc: Optional[str] = None, + ) -> Optional[ProcessGroup]: ... + def merge_remote_group( + self, + store: Store, + size: int, + timeout: timedelta, + group_name: Optional[str] = None, + group_desc: Optional[str] = None, + ) -> ProcessGroup: ... def abort(self) -> None: ... def set_timeout(self, timeout: timedelta) -> None: ... def shutdown(self) -> None: ... diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 129984e6c10d..6261679dcdef 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -58,7 +58,7 @@ class _PyInterpreterFrame: f_globals: dict[str, object] f_builtins: dict[str, object] f_lasti: int - f_lineo: int + f_lineno: int f_back: types.FrameType # A tuple containing cell objects captured by this frame. closure: tuple[types.CellType] diff --git a/torch/__init__.py b/torch/__init__.py index 95459337c2ed..99cb83db84b8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -367,14 +367,8 @@ def _load_global_deps() -> None: "nccl": "libnccl.so.*[0-9]", "nvtx": "libnvToolsExt.so.*[0-9]", "nvshmem": "libnvshmem_host.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", } - # cufiile is only available on cuda 12+ - # TODO: Remove once CUDA 11.8 binaries are deprecated - if cuda_version is not None: - t_version = cuda_version.split(".") - t_major = int(t_version[0]) # type: ignore[operator] - if t_major >= 12: - cuda_libs["cufile"] = "libcufile.so.*[0-9]" is_cuda_lib_err = [ lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index abb94b109cc0..8e9796d2f7c1 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -418,6 +418,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, aten.new_ones, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index f93a0bf84fb4..634c4b6b4954 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1667,9 +1667,9 @@ def native_layer_norm_backward( N = prod(inner_dims) # type: ignore[arg-type] M = prod(outer_dims) # type: ignore[arg-type] - from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + from torch.fx.experimental.symbolic_shapes import statically_known_true - if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + if statically_known_true(M == 0) or statically_known_true(N == 0): return ( input.new_zeros(input_shape) if output_mask[0] else None, input.new_zeros(input_shape[axis:]) if output_mask[1] else None, @@ -1743,6 +1743,81 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm_backward.default) +def _fused_rms_norm_backward( + grad_out: Tensor, + input: Tensor, + normalized_shape: list[int], + rstd: Tensor, + weight: Optional[Tensor], + output_mask: list[bool], +) -> tuple[Optional[Tensor], Optional[Tensor]]: + input_shape = input.shape + input_ndim = input.dim() + computation_dtype = utils.get_computation_dtype(input.dtype) + + grad_out_cast = grad_out.to( + computation_dtype, memory_format=torch.contiguous_format + ) + input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format) + weight_cast = ( + weight.to(computation_dtype, memory_format=torch.contiguous_format) + if weight is not None + else None + ) + assert grad_out_cast is not None + + axis = input_ndim - len(normalized_shape) + inner_dims = input_shape[axis:] + outer_dims = input_shape[:axis] + inner_dim_indices: list[int] = [] + outer_dim_indices: list[int] = [] + for i in range(input_ndim): + if i >= axis: + inner_dim_indices.append(i) + else: + outer_dim_indices.append(i) + + N = prod(inner_dims) # type: ignore[arg-type] + M = prod(outer_dims) # type: ignore[arg-type] + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): + return ( + input.new_zeros(input_shape) if output_mask[0] else None, + input.new_zeros(input_shape[axis:]) if output_mask[1] else None, + ) + + rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr] + if weight_cast is not None: + grad_x_hat = grad_out_cast * weight_cast + else: + grad_x_hat = grad_out_cast + + d_input: Optional[Tensor] = None + d_weight: Optional[Tensor] = None + + x_hat = input_cast * rstd + + if output_mask[0]: + sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True) + d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd + + if output_mask[1] and weight_cast is not None: + d_weight_full_shape = grad_out_cast * x_hat + if len(outer_dim_indices) > 0: + d_weight = torch.sum( + d_weight_full_shape, dim=outer_dim_indices, keepdim=False + ) + else: + d_weight = d_weight_full_shape + + return ( + _maybe_cast(d_input, input.dtype), + _maybe_cast(d_weight, input.dtype), + ) + + def native_batch_norm_helper( input: Tensor, weight: Optional[Tensor], diff --git a/torch/_dynamo/backends/onnxrt.py b/torch/_dynamo/backends/onnxrt.py index 6830c0409620..71c5e1765810 100644 --- a/torch/_dynamo/backends/onnxrt.py +++ b/torch/_dynamo/backends/onnxrt.py @@ -4,35 +4,38 @@ # to the right people, please tag related GitHub issues with `module: onnx`. # # Maintainers' Github IDs: wschin, xadupre -from torch.onnx._internal.onnxruntime import ( - is_onnxrt_backend_supported, - torch_compile_backend, -) - -from .registry import register_backend - - -def has_onnxruntime(): - # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() - return is_onnxrt_backend_supported() - - -if is_onnxrt_backend_supported(): - register_backend(name="onnxrt", compiler_fn=torch_compile_backend) -else: - - def information_displaying_backend(*args, **kwargs): - raise ImportError( - "onnxrt is not registered as a backend. " - "Please make sure all dependencies such as " - "numpy, onnx, onnxscript, and onnxruntime-training are installed. " - "Suggested procedure to fix dependency problem:\n" - " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" - " (2) Open a new python terminal.\n" - " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" - " (4) If it returns `True`, then you can use `onnxrt` backend.\n" - " (5) If it returns `False`, please execute the package importing section in " - "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." - ) - - register_backend(name="onnxrt", compiler_fn=information_displaying_backend) +# from torch.onnx._internal.onnxruntime import ( +# is_onnxrt_backend_supported, +# torch_compile_backend, +# ) + +# from .registry import register_backend + +""" +Placeholder for onnxruntime backend for dynamo +""" + +# def has_onnxruntime(): +# # FIXME: update test/dynamo/test_backends.py to call is_onnxrt_backend_supported() +# return is_onnxrt_backend_supported() + + +# if is_onnxrt_backend_supported(): +# register_backend(name="onnxrt", compiler_fn=torch_compile_backend) +# else: + +# def information_displaying_backend(*args, **kwargs): +# raise ImportError( +# "onnxrt is not registered as a backend. " +# "Please make sure all dependencies such as " +# "numpy, onnx, onnxscript, and onnxruntime-training are installed. " +# "Suggested procedure to fix dependency problem:\n" +# " (1) pip or conda install numpy onnx onnxscript onnxruntime-training.\n" +# " (2) Open a new python terminal.\n" +# " (3) Call the API `torch.onnx.is_onnxrt_backend_supported()`:\n" +# " (4) If it returns `True`, then you can use `onnxrt` backend.\n" +# " (5) If it returns `False`, please execute the package importing section in " +# "torch/onnx/_internal/onnxruntime.py under pdb line-by-line to see which import fails." +# ) + +# register_backend(name="onnxrt", compiler_fn=information_displaying_backend) diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py index 3252ea91409f..8bdf155e0060 100644 --- a/torch/_dynamo/bytecode_analysis.py +++ b/torch/_dynamo/bytecode_analysis.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for analyzing and optimizing Python bytecode. Key functionality includes: @@ -18,8 +16,13 @@ import dataclasses import dis import sys -from typing import Any, Union +from typing import Any, TYPE_CHECKING, Union + +if TYPE_CHECKING: + # TODO(lucaskabela): consider moving Instruction into this file + # and refactoring in callsite; that way we don't have to guard this import + from .bytecode_transformation import Instruction TERMINAL_OPCODES = { dis.opmap["RETURN_VALUE"], @@ -33,7 +36,7 @@ TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"]) else: TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"]) -if sys.version_info >= (3, 12): +if (3, 12) <= sys.version_info < (3, 14): TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"]) if sys.version_info >= (3, 13): TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"]) @@ -45,7 +48,7 @@ stack_effect = dis.stack_effect -def get_indexof(insts): +def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]: """ Get a mapping from instruction memory address to index in instruction list. Additionally checks that each instruction only appears once in the list. @@ -57,12 +60,12 @@ def get_indexof(insts): return indexof -def remove_dead_code(instructions): +def remove_dead_code(instructions: list["Instruction"]) -> list["Instruction"]: """Dead code elimination""" indexof = get_indexof(instructions) live_code = set() - def find_live_code(start): + def find_live_code(start: int) -> None: for i in range(start, len(instructions)): if i in live_code: return @@ -71,6 +74,7 @@ def find_live_code(start): if inst.exn_tab_entry: find_live_code(indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None find_live_code(indexof[inst.target]) if inst.opcode in TERMINAL_OPCODES: return @@ -102,7 +106,7 @@ def find_live_code(start): return [inst for i, inst in enumerate(instructions) if i in live_code] -def remove_pointless_jumps(instructions): +def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]: """Eliminate jumps to the next instruction""" pointless_jumps = { id(a) @@ -112,11 +116,11 @@ def remove_pointless_jumps(instructions): return [inst for inst in instructions if id(inst) not in pointless_jumps] -def propagate_line_nums(instructions): +def propagate_line_nums(instructions: list["Instruction"]) -> None: """Ensure every instruction has line number set in case some are removed""" cur_line_no = None - def populate_line_num(inst): + def populate_line_num(inst: "Instruction") -> None: nonlocal cur_line_no if inst.starts_line: cur_line_no = inst.starts_line @@ -127,12 +131,12 @@ def populate_line_num(inst): populate_line_num(inst) -def remove_extra_line_nums(instructions): +def remove_extra_line_nums(instructions: list["Instruction"]) -> None: """Remove extra starts line properties before packing bytecode""" cur_line_no = None - def remove_line_num(inst): + def remove_line_num(inst: "Instruction") -> None: nonlocal cur_line_no if inst.starts_line is None: return @@ -152,12 +156,14 @@ class ReadsWrites: visited: set[Any] -def livevars_analysis(instructions, instruction): +def livevars_analysis( + instructions: list["Instruction"], instruction: "Instruction" +) -> set[Any]: indexof = get_indexof(instructions) must = ReadsWrites(set(), set(), set()) may = ReadsWrites(set(), set(), set()) - def walk(state, start): + def walk(state: ReadsWrites, start: int) -> None: if start in state.visited: return state.visited.add(start) @@ -177,6 +183,7 @@ def walk(state, start): if inst.exn_tab_entry: walk(may, indexof[inst.exn_tab_entry.target]) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None walk(may, indexof[inst.target]) state = may if inst.opcode in TERMINAL_OPCODES: @@ -197,19 +204,19 @@ class StackSize: high: Union[int, float] fixed_point: FixedPointBox - def zero(self): + def zero(self) -> None: self.low = 0 self.high = 0 self.fixed_point.value = False - def offset_of(self, other, n): + def offset_of(self, other: "StackSize", n: int) -> None: prior = (self.low, self.high) self.low = min(self.low, other.low + n) self.high = max(self.high, other.high + n) if (self.low, self.high) != prior: self.fixed_point.value = False - def exn_tab_jump(self, depth): + def exn_tab_jump(self, depth: int) -> None: prior = (self.low, self.high) self.low = min(self.low, depth) self.high = max(self.high, depth) @@ -217,7 +224,7 @@ def exn_tab_jump(self, depth): self.fixed_point.value = False -def stacksize_analysis(instructions) -> Union[int, float]: +def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]: assert instructions fixed_point = FixedPointBox() stack_sizes = { @@ -238,6 +245,7 @@ def stacksize_analysis(instructions) -> Union[int, float]: eff = stack_effect(inst.opcode, inst.arg, jump=False) stack_sizes[next_inst].offset_of(stack_size, eff) if inst.opcode in JUMP_OPCODES: + assert inst.target is not None, f"missing target: {inst}" stack_sizes[inst.target].offset_of( stack_size, stack_effect(inst.opcode, inst.arg, jump=True) ) @@ -247,11 +255,6 @@ def stacksize_analysis(instructions) -> Union[int, float]: depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1 stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth) - if False: - for inst in instructions: - stack_size = stack_sizes[inst] - print(stack_size.low, stack_size.high, inst) - low = min(x.low for x in stack_sizes.values()) high = max(x.high for x in stack_sizes.values()) diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 9226a61577d8..165182d93d23 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for analyzing, transforming and manipulating Python bytecode. It includes functionality for: @@ -23,7 +21,7 @@ import sys import types import uuid -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from typing import Any, Callable, cast, Optional, Union from ..utils._backport_slots import dataclass_slots @@ -53,7 +51,9 @@ def __repr__(self) -> str: f"depth={self.depth}, lasti={self.lasti})" ) - def __eq__(self, o) -> bool: + def __eq__(self, o: object) -> bool: + if not isinstance(o, InstructionExnTabEntry): + return False return ( self.start is o.start and self.end is o.end @@ -84,7 +84,7 @@ class Instruction: def __hash__(self) -> int: return id(self) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: return id(self) == id(other) def short_inst_repr(self) -> str: @@ -145,22 +145,26 @@ def __repr__(self) -> str: if sys.version_info >= (3, 12): - def inst_has_op_bits(name): + def inst_has_op_bits(name: str) -> bool: return name in ("LOAD_ATTR", "LOAD_GLOBAL", "LOAD_SUPER_ATTR") elif sys.version_info >= (3, 11): - def inst_has_op_bits(name): + def inst_has_op_bits(name: str) -> bool: return name == "LOAD_GLOBAL" else: - def inst_has_op_bits(name): + def inst_has_op_bits(name: str): return False def create_instruction( - name, *, arg=None, argval=_NotProvided, target=None + name: str, + *, + arg: Optional[int] = None, + argval: Optional[Any] = _NotProvided, + target: Optional[Instruction] = None, ) -> Instruction: """ At most one of `arg`, `argval`, and `target` can be not None/_NotProvided. @@ -198,12 +202,12 @@ def create_instruction( # Python 3.11 remaps -def create_jump_absolute(target) -> Instruction: +def create_jump_absolute(target: Instruction) -> Instruction: inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE" return create_instruction(inst, target=target) -def create_load_const(val, checked=True) -> Instruction: +def create_load_const(val: Any, checked: bool = True) -> Instruction: """ In general we should only create `LOAD_CONST` for immutable objects, but sometimes it's convenient _and safe_ for Dynamo create `LOAD_CONST` for @@ -220,7 +224,7 @@ def create_dup_top() -> Instruction: return create_instruction("DUP_TOP") -def create_rot_n(n) -> list[Instruction]: +def create_rot_n(n: int) -> list[Instruction]: """ Returns a "simple" sequence of instructions that rotates TOS to the n-th position in the stack. For Python < 3.11, returns a single ROT_* @@ -265,17 +269,18 @@ def add_push_null( In this case, instructions WILL be modified. """ if isinstance(inst_or_insts, Instruction): - insts = [inst_or_insts] + insts: list[Instruction] = [inst_or_insts] else: + assert isinstance(inst_or_insts, list) insts = inst_or_insts - def inst_has_bit_set(idx): + def inst_has_bit_set(idx: int) -> bool: assert insts[idx].arg is not None - return insts[idx].arg & 1 == 1 + return insts[idx].arg & 1 == 1 # type: ignore[operator] - def set_inst_bit(idx): + def set_inst_bit(idx: int) -> None: assert insts[idx].arg is not None - insts[idx].arg |= 1 + insts[idx].arg |= 1 # type: ignore[operator] if sys.version_info >= (3, 13): # In 3.13, NULL follows the callable @@ -312,8 +317,9 @@ def add_push_null_call_function_ex( is not set, due to an expected CALL_FUNCTION_EX instruction. """ if isinstance(inst_or_insts, Instruction): - insts = [inst_or_insts] + insts: list[Instruction] = [inst_or_insts] else: + assert isinstance(inst_or_insts, list) insts = inst_or_insts if sys.version_info < (3, 11): @@ -334,7 +340,7 @@ def add_push_null_call_function_ex( return insts -def create_call_function(nargs, push_null) -> list[Instruction]: +def create_call_function(nargs: int, push_null: bool) -> list[Instruction]: """ Creates a sequence of instructions that makes a function call. @@ -389,7 +395,7 @@ def create_call_function(nargs, push_null) -> list[Instruction]: return [create_instruction("CALL_FUNCTION", arg=nargs)] -def create_call_method(nargs) -> list[Instruction]: +def create_call_method(nargs: int) -> list[Instruction]: if sys.version_info >= (3, 12): return [create_instruction("CALL", arg=nargs)] if sys.version_info >= (3, 11): @@ -400,19 +406,19 @@ def create_call_method(nargs) -> list[Instruction]: return [create_instruction("CALL_METHOD", arg=nargs)] -def create_load_method(name) -> Instruction: +def create_load_method(name: str) -> Instruction: if sys.version_info >= (3, 12): # in 3.12, create a LOAD_ATTR instruction with the low bit set return create_instruction("LOAD_ATTR", arg=1, argval=name) return create_instruction("LOAD_METHOD", argval=name) -def create_setup_with(target) -> Instruction: +def create_setup_with(target: Instruction) -> Instruction: opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH" return create_instruction(opname, target=target) -def create_swap(n) -> list[Instruction]: +def create_swap(n: int) -> list[Instruction]: if sys.version_info >= (3, 11): return [create_instruction("SWAP", arg=n)] # in Python < 3.11, SWAP is a macro that expands to multiple instructions @@ -465,7 +471,7 @@ def lnotab_writer( assert sys.version_info < (3, 10) lnotab: list[int] = [] - def update(lineno_new, byteno_new): + def update(lineno_new: int, byteno_new: int) -> None: nonlocal byteno, lineno while byteno_new != byteno or lineno_new != lineno: byte_offset = max(0, min(byteno_new - byteno, 255)) @@ -478,7 +484,9 @@ def update(lineno_new, byteno_new): return lnotab, update -def linetable_310_writer(first_lineno): +def linetable_310_writer( + first_lineno: int, +) -> tuple[list[int], Callable[[int, int], None], Callable[[int], None]]: """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt @@ -490,7 +498,7 @@ def linetable_310_writer(first_lineno): lineno_delta = 0 byteno = 0 - def _update(byteno_delta, lineno_delta): + def _update(byteno_delta: int, lineno_delta: int) -> None: while byteno_delta != 0 or lineno_delta != 0: byte_offset = max(0, min(byteno_delta, 254)) line_offset = max(-127, min(lineno_delta, 127)) @@ -499,7 +507,7 @@ def _update(byteno_delta, lineno_delta): lineno_delta -= line_offset linetable.extend((byte_offset, line_offset & 0xFF)) - def update(lineno_new, byteno_new): + def update(lineno_new: int, byteno_new: int) -> None: nonlocal lineno, lineno_delta, byteno byteno_delta = byteno_new - byteno byteno = byteno_new @@ -507,7 +515,7 @@ def update(lineno_new, byteno_new): lineno_delta = lineno_new - lineno lineno = lineno_new - def end(total_bytes): + def end(total_bytes: int) -> None: _update(total_bytes - byteno, lineno_delta) return linetable, update, end @@ -528,7 +536,9 @@ def encode_varint(n: int) -> list[int]: return b -def linetable_311_writer(first_lineno: int): +def linetable_311_writer( + first_lineno: int, +) -> tuple[list[int], Callable[[Optional["dis.Positions"], int], None]]: """ Used to create typing.CodeType.co_linetable See https://github.com/python/cpython/blob/3.11/Objects/locations.md @@ -538,11 +548,11 @@ def linetable_311_writer(first_lineno: int): linetable = [] lineno = first_lineno - def update(positions: "dis.Positions", inst_size): + def update(positions: Optional["dis.Positions"], inst_size: int) -> None: nonlocal lineno lineno_new = positions.lineno if positions else None - def _update(delta, size): + def _update(delta: int, size: int) -> None: assert 0 < size <= 8 # first byte - use 13 (no column info) is positions is # malformed, otherwise use 14 (long form) @@ -721,7 +731,9 @@ def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, return bytes(code), bytes(lnotab) -def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int): +def _get_instruction_by_offset( + offset_to_inst: dict[int, Instruction], offset: int +) -> Optional[Instruction]: """ Get the instruction located at a given offset, accounting for EXTENDED_ARGs """ @@ -731,9 +743,11 @@ def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: i return None -def virtualize_jumps(instructions) -> None: +def virtualize_jumps(instructions: Iterable[Instruction]) -> None: """Replace jump targets with pointers to make editing easier""" - jump_targets = {inst.offset: inst for inst in instructions} + jump_targets = { + inst.offset: inst for inst in instructions if inst.offset is not None + } for inst in instructions: if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel: @@ -756,7 +770,7 @@ def flip_jump_direction(instruction: Instruction) -> None: assert instruction.opcode in _REL_JUMPS -def _get_instruction_front(instructions: list[Instruction], idx: int): +def _get_instruction_front(instructions: list[Instruction], idx: int) -> Instruction: """ i.e. get the first EXTENDED_ARG instruction (if any) when targeting instructions[idx] with a jump. @@ -770,7 +784,7 @@ def _get_instruction_front(instructions: list[Instruction], idx: int): return target -def devirtualize_jumps(instructions): +def devirtualize_jumps(instructions: list[Instruction]) -> None: """Fill in args for virtualized jump target after instructions may have moved""" jumps = set(dis.hasjabs).union(set(dis.hasjrel)) @@ -778,6 +792,11 @@ def devirtualize_jumps(instructions): for inst in instructions: if inst.opcode in jumps: if inst.opcode not in dis.hasjabs: + assert ( + inst.target is not None + and inst.target.offset is not None + and inst.offset is not None + ) if inst.target.offset < inst.offset: if sys.version_info < (3, 11): raise RuntimeError("Got negative jump offset for Python < 3.11") @@ -796,6 +815,7 @@ def devirtualize_jumps(instructions): # compute jump instruction arg for inst in instructions: if inst.opcode in jumps: + assert inst.target is not None target = _get_instruction_front(instructions, indexof[inst.target]) if inst.opcode in dis.hasjabs: if sys.version_info < (3, 10): @@ -808,6 +828,7 @@ def devirtualize_jumps(instructions): raise RuntimeError("Python 3.11+ should not have absolute jumps") else: # relative jump # byte offset between target and next instruction + assert target.offset is not None and inst.offset is not None inst.arg = abs( int(target.offset - inst.offset - instruction_size(inst)) ) @@ -818,7 +839,9 @@ def devirtualize_jumps(instructions): inst.argrepr = f"to {target.offset}" -def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]): +def virtualize_exception_table( + exn_tab_bytes: bytes, instructions: list[Instruction] +) -> None: """Replace exception table entries with pointers to make editing easier""" exn_tab = parse_exception_table(exn_tab_bytes) offset_to_inst = {cast(int, inst.offset): inst for inst in instructions} @@ -827,7 +850,7 @@ def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruct exn_tab_iter = iter(exn_tab) try: - def step(): + def step() -> tuple[ExceptionTableEntry, InstructionExnTabEntry]: nonlocal end_offset_idx entry = next(exn_tab_iter) # find rightmost offset <= entry.end, since entry.end may not be @@ -841,9 +864,9 @@ def step(): assert end_offset_idx > 0 end_offset = offsets[end_offset_idx - 1] inst_entry = InstructionExnTabEntry( - _get_instruction_by_offset(offset_to_inst, entry.start), - _get_instruction_by_offset(offset_to_inst, end_offset), - _get_instruction_by_offset(offset_to_inst, entry.target), + _get_instruction_by_offset(offset_to_inst, entry.start), # type: ignore[arg-type] + _get_instruction_by_offset(offset_to_inst, end_offset), # type: ignore[arg-type] + _get_instruction_by_offset(offset_to_inst, entry.target), # type: ignore[arg-type] entry.depth, entry.lasti, ) @@ -851,6 +874,7 @@ def step(): entry, inst_entry = step() for inst in instructions: + assert inst.offset is not None while inst.offset > entry.end: entry, inst_entry = step() if inst.offset >= entry.start: @@ -872,15 +896,18 @@ def compute_exception_table( start = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.start] ).offset + assert start is not None # point to the last 2 bytes of the end instruction end = ( cast(int, inst.exn_tab_entry.end.offset) + instruction_size(inst.exn_tab_entry.end) - 2 ) + assert end is not None target = _get_instruction_front( instructions, indexof[inst.exn_tab_entry.target] ).offset + assert target is not None key = (start, end) val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti) if key in exn_dict: @@ -900,7 +927,7 @@ def compute_exception_table( key_stack: list[tuple[int, int]] = [] exn_tab: list[ExceptionTableEntry] = [] - def pop(): + def pop() -> None: """ Pop the key_stack and append an exception table entry if possible. """ @@ -934,7 +961,7 @@ def pop(): def check_inst_exn_tab_entries_nested( - tab: list[InstructionExnTabEntry], indexof + tab: list[InstructionExnTabEntry], indexof: dict[Instruction, int] ) -> None: """ Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's, @@ -979,7 +1006,7 @@ def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None: instructions[i].exn_tab_entry = copy.copy(entry) -def check_inst_exn_tab_entries_valid(instructions: list[Instruction]): +def check_inst_exn_tab_entries_valid(instructions: list[Instruction]) -> None: """ Checks that exn_tab_entries of instructions are valid. An entry's start, end, and target must be in instructions. @@ -1012,7 +1039,9 @@ def strip_extended_args(instructions: list[Instruction]) -> None: # instruction, exception table entries, and positions. # Returns the modified sequence of instructions (including the modified # old instruction!) that can be manipulated elsewhere. -def overwrite_instruction(old_inst, new_insts): +def overwrite_instruction( + old_inst: Instruction, new_insts: list[Instruction] +) -> list[Instruction]: # update old_inst.exnt_tab_entry.end if necessary if ( old_inst.exn_tab_entry @@ -1161,7 +1190,7 @@ def fix_extended_args(instructions: list[Instruction]) -> int: """Fill in correct argvals for EXTENDED_ARG ops""" output: list[Instruction] = [] - def maybe_pop_n(n): + def maybe_pop_n(n: int) -> None: for _ in range(n): if output and output[-1].opcode == dis.EXTENDED_ARG: output.pop() @@ -1190,7 +1219,7 @@ def maybe_pop_n(n): return added -def instruction_size(inst) -> int: +def instruction_size(inst: Instruction) -> int: import torch if sys.version_info >= (3, 11): @@ -1198,21 +1227,21 @@ def instruction_size(inst) -> int: return 2 -def check_offsets(instructions) -> None: +def check_offsets(instructions: Sequence[Instruction]) -> None: offset = 0 for inst in instructions: assert inst.offset == offset offset += instruction_size(inst) -def update_offsets(instructions) -> None: +def update_offsets(instructions: Sequence[Instruction]) -> None: offset = 0 for inst in instructions: inst.offset = offset offset += instruction_size(inst) -def debug_bytes(*args) -> str: +def debug_bytes(*args: bytes) -> str: index = range(max(map(len, args))) result = [ " ".join(f"{x:03}" for x in arg) @@ -1224,7 +1253,7 @@ def debug_bytes(*args) -> str: return "bytes mismatch\n" + "\n".join(result) -def debug_checks(code): +def debug_checks(code: types.CodeType) -> None: """Make sure our assembler produces same bytes as we start with""" dode = transform_code_object(code, lambda x, y: None, safe=True) assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) @@ -1237,7 +1266,7 @@ def debug_checks(code): HAS_CONST = set(dis.hasconst) -def get_const_index(code_options, val) -> int: +def get_const_index(code_options: dict[str, Any], val: Any) -> int: for i, v in enumerate(code_options["co_consts"]): # NOTE: stronger comparison is required, since we have # examples where two values compare equal but have @@ -1249,11 +1278,15 @@ def get_const_index(code_options, val) -> int: return len(code_options["co_consts"]) - 1 -def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None): +def fix_vars( + instructions: list[Instruction], + code_options: dict[str, Any], + varname_from_oparg: Optional[Callable[..., Any]] = None, +) -> None: # compute instruction arg from argval if arg is not provided names = {name: idx for idx, name in enumerate(code_options["co_names"])} - def get_name_index(name) -> int: + def get_name_index(name: str) -> int: try: idx = names[name] except KeyError: @@ -1288,7 +1321,7 @@ def get_name_index(name) -> int: } for i in range(len(instructions)): - def should_compute_arg(): + def should_compute_arg() -> bool: # argval is prioritized over arg return instructions[i].argval is not _NotProvided @@ -1356,7 +1389,7 @@ def should_compute_arg(): instructions[i].arg = idx -def clear_instruction_args(instructions): +def clear_instruction_args(instructions: list[Instruction]) -> None: # Clear the instruction arg for instructions that have argvals. # Useful for using dis'd bytecode within generated bytecode. for inst in instructions: @@ -1413,7 +1446,11 @@ def get_code_keys() -> list[str]: return keys -def transform_code_object(code, transformations, safe=False) -> types.CodeType: +def transform_code_object( + code: types.CodeType, + transformations: Callable[[list[Instruction], dict[str, Any]], Any], + safe: bool = False, +) -> types.CodeType: keys = get_code_keys() code_options = {k: getattr(code, k) for k in keys} assert len(code_options["co_varnames"]) == code_options["co_nlocals"] @@ -1466,7 +1503,7 @@ def clean_and_assemble_instructions( return instructions, types.CodeType(*[code_options[k] for k in keys]) -def populate_kw_names_argval(instructions, consts): +def populate_kw_names_argval(instructions: Sequence[Instruction], consts: Any) -> None: for inst in instructions: if inst.opname == "KW_NAMES": inst.argval = consts[inst.arg] @@ -1474,7 +1511,7 @@ def populate_kw_names_argval(instructions, consts): # If safe=True, we do not make any bytecode modifications. # Mainly used for debugging bytecode_transformation (see debug_checks) -def cleaned_instructions(code, safe=False) -> list[Instruction]: +def cleaned_instructions(code: types.CodeType, safe: bool = False) -> list[Instruction]: instructions = _cached_cleaned_instructions(code, safe) # We have a lot of code that implicitly mutates the instruction array. We # could do better here by making the copies explicit when necessary. @@ -1482,7 +1519,7 @@ def cleaned_instructions(code, safe=False) -> list[Instruction]: # Copy an instructions array, making sure to remap the individual instruction targets. -def _clone_instructions(instructions): +def _clone_instructions(instructions: Sequence[Instruction]) -> list[Instruction]: # This is super hot and this is the fastest way to do this (tried copy.copy # and dataclasses.replace). copied = [ @@ -1504,10 +1541,10 @@ def _clone_instructions(instructions): remap = dict(zip(instructions, copied)) # Handle `None` in the remapper so we don't need an extra `if`. - remap[None] = None + remap[None] = None # type: ignore[index, assignment] for i in copied: - i.target = remap[i.target] + i.target = remap[i.target] # type: ignore[index] if entry := i.exn_tab_entry: i.exn_tab_entry = InstructionExnTabEntry( remap[entry.start], @@ -1520,7 +1557,9 @@ def _clone_instructions(instructions): @functools.lru_cache -def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: +def _cached_cleaned_instructions( + code: types.CodeType, safe: bool = False +) -> Sequence[Instruction]: instructions = list(map(convert_instruction, dis.get_instructions(code))) check_offsets(instructions) if sys.version_info >= (3, 11): @@ -1548,7 +1587,7 @@ def _cached_cleaned_instructions(code, safe=False) -> Sequence[Instruction]: _unique_id_counter = itertools.count() -def unique_id(name, with_uuid=False) -> str: +def unique_id(name: str, with_uuid: bool = False) -> str: ret = f"{name}_{next(_unique_id_counter)}" if with_uuid: ret += f"_{uuid.uuid4()}".replace("-", "_") @@ -1560,7 +1599,12 @@ def is_generator(code: types.CodeType) -> bool: return (code.co_flags & co_generator) > 0 -def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): +def bytecode_from_template( + fn: Callable[..., Any], + varname_map: Optional[Mapping[Any, Any]] = None, + noreturn: bool = True, + noprefix: bool = True, +) -> list[Instruction]: """Generates bytecode from a template function `fn` for use in dynamo bytecode generation. diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index cff7ea3fef33..d1a46742f37a 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -1,7 +1,7 @@ -# mypy: allow-untyped-defs import logging import weakref from dataclasses import dataclass +from typing import Any, Optional from torch._guards import CompileId @@ -9,7 +9,7 @@ from .types import DynamoFrameType -log = logging.getLogger(__name__) +log: logging.Logger = logging.getLogger(__name__) """ [Note on cache size limit] @@ -99,7 +99,9 @@ def will_compilation_exceed_specific_limit(self, limit: int) -> bool: return self.num_cache_entries_with_same_id_matched_objs >= limit -def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): +def _get_weakref_from_f_locals( + frame: DynamoFrameType, local_name: str +) -> Optional[weakref.ref[Any]]: obj = frame.f_locals.get(local_name, None) weak_id = None try: @@ -109,7 +111,7 @@ def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str): return weak_id -def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: +def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry: Any) -> bool: """ Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones in frame.f_locals. @@ -131,7 +133,7 @@ def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool: def compute_cache_size( - frame: DynamoFrameType, cache_entry + frame: DynamoFrameType, cache_entry: Any ) -> CacheSizeRelevantForFrame: # Walk the linked list to calculate the cache size num_cache_entries = 0 diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 946ad280570a..f64ef6e5231a 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides utilities for generating Python bytecode in PyTorch's Dynamo system. It includes functionality for: @@ -18,7 +16,8 @@ import sys import types from collections import Counter -from typing import Optional, TYPE_CHECKING, Union +from collections.abc import Iterable +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch.nn from torch.utils._ordered_set import OrderedSet @@ -55,6 +54,8 @@ if TYPE_CHECKING: + from torch._dynamo.variables.builder import GraphArg + from .symbolic_convert import InstructionTranslatorBase @@ -74,8 +75,8 @@ def __init__( tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, - tempvars=None, - overridden_sources=None, + tempvars: Optional[dict[Union[VariableTracker, Source], Any]] = None, + overridden_sources: Optional[dict[Source, Source]] = None, ) -> None: self.root = root self.top_of_stack: Optional[Union[VariableTracker, Source]] = None @@ -86,7 +87,7 @@ def __init__( # locals, and maps the VariableTracker/Source to the local variable # name. Note that it could map to None initially, in which case we'll # overwrite it to map to real temporary names via `add_cache`. - self.tempvars = tempvars or {} + self.tempvars: dict[Union[VariableTracker, Source], Any] = tempvars or {} self.tx = tx self.graph_output_var = graph_output_var self.code_options = self.tx.output.code_options @@ -98,7 +99,9 @@ def __init__( # without affecting other components, e.g., guards. self.overridden_sources: dict[Source, Source] = overridden_sources or {} - def restore_stack(self, stack_values, *, value_from_source=True): + def restore_stack( + self, stack_values: list[Any], *, value_from_source: bool = True + ) -> None: prev = self.value_from_source self.value_from_source &= value_from_source try: @@ -106,14 +109,18 @@ def restore_stack(self, stack_values, *, value_from_source=True): finally: self.value_from_source = prev - def graph_output_vars(self): + def graph_output_vars(self) -> list[VariableTracker]: return [x.variable for x in self.graph_outputs.values()] - def call_reconstruct(self, value): + def call_reconstruct( + self, value: Union[VariableTracker, Source, "GraphArg"] + ) -> None: res = value.reconstruct(self) assert res is None, f"reconstruct!=None {value}" - def add_push_null(self, gen_fn, call_function_ex=False): + def add_push_null( + self, gen_fn: Callable[[], None], call_function_ex: bool = False + ) -> None: """ `gen_fn` generates instructions via PyCodegen methods that push a single callable to the stack. @@ -142,7 +149,9 @@ def add_push_null(self, gen_fn, call_function_ex=False): # NULL will be at top of stack self.clear_tos() - def __call__(self, value, allow_cache=True): + def __call__( + self, value: Union[VariableTracker, Source], allow_cache: bool = True + ) -> None: """ Generate code such that top-of-stack (TOS) is set to value. @@ -297,7 +306,7 @@ def __call__(self, value, allow_cache=True): value.as_tensor(self.tx, torch.float64) ) - def gen_fn(): + def gen_fn() -> None: self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -322,7 +331,7 @@ def gen_fn(): output.extend(create_call_function(1, False)) elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: - def gen_fn(): + def gen_fn() -> None: self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append(self.create_load_attr("item")) @@ -363,7 +372,7 @@ def gen_fn(): self.top_of_stack = value - def add_graph_output(self, value): + def add_graph_output(self, value: VariableTracker) -> int: graph_outputs_key = id(value.as_proxy()) if graph_outputs_key not in self.graph_outputs: self.graph_outputs[graph_outputs_key] = GraphOutputEntry( @@ -371,25 +380,26 @@ def add_graph_output(self, value): ) return graph_outputs_key - def load_graph_output(self, index): + def load_graph_output(self, index: int) -> None: output = self._output + assert self.graph_output_var is not None output.append(self.create_load(self.graph_output_var)) output.append(self.create_load_const(index)) output.append(self.create_binary_subscr()) - def add_cache(self, value): + def add_cache(self, value: Union[VariableTracker, Source]) -> None: var = self.new_var() self.tempvars[value] = var self._output.append(self.create_store(var)) - def foreach(self, items): + def foreach(self, items: Iterable[Union[VariableTracker, Source]]) -> None: for i in items: self(i) def create_binary_subscr(self) -> Instruction: return create_instruction("BINARY_SUBSCR") - def setup_globally_cached(self, name, value): + def setup_globally_cached(self, name: str, value: Any) -> list[Instruction]: """Store value in a new global""" name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) f_globals = self.tx.f_globals @@ -399,15 +409,15 @@ def setup_globally_cached(self, name, value): f_globals[name] = value return [self.create_load_global(name, add=True)] - def clear_tos(self): + def clear_tos(self) -> None: self.top_of_stack = None - def append_output(self, inst): + def append_output(self, inst: Instruction) -> None: assert isinstance(inst, Instruction) self._output.append(inst) self.clear_tos() - def extend_output(self, insts): + def extend_output(self, insts: list[Instruction]) -> None: assert all(isinstance(x, Instruction) for x in insts) self._output.extend(insts) self.clear_tos() @@ -415,66 +425,68 @@ def extend_output(self, insts): def get_instructions(self) -> list[Instruction]: return self._output - def create_load(self, name) -> Instruction: + def create_load(self, name: str) -> Instruction: assert name in self.code_options["co_varnames"], f"{name} missing" return create_instruction("LOAD_FAST", argval=name) - def create_load_closure(self, name) -> Instruction: + def create_load_closure(self, name: str) -> Instruction: assert name in self.cell_and_freevars() inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" return create_instruction(inst_name, argval=name) - def create_load_deref(self, name) -> Instruction: + def create_load_deref(self, name: str) -> Instruction: assert name in self.cell_and_freevars() return create_instruction("LOAD_DEREF", argval=name) - def create_store(self, name) -> Instruction: + def create_store(self, name: str) -> Instruction: assert name in self.code_options["co_varnames"], f"{name} missing" return create_instruction("STORE_FAST", argval=name) - def create_store_deref(self, name) -> Instruction: + def create_store_deref(self, name: str) -> Instruction: assert name in self.cell_and_freevars() return create_instruction("STORE_DEREF", argval=name) - def create_load_global(self, name, add=False) -> Instruction: + def create_load_global(self, name: str, add: bool = False) -> Instruction: if add: self.tx.output.update_co_names(name) assert name in self.code_options["co_names"], f"{name} not in co_names" return create_instruction("LOAD_GLOBAL", argval=name) - def create_load_const(self, value) -> Instruction: + def create_load_const(self, value: Any) -> Instruction: return create_load_const(value) - def create_load_const_unchecked(self, value) -> Instruction: + def create_load_const_unchecked(self, value: Any) -> Instruction: return create_load_const(value, checked=False) - def load_method(self, name): + def load_method(self, name: str) -> None: self.tx.output.update_co_names(name) self.append_output(create_load_method(name)) - def call_method(self, nargs): + def call_method(self, nargs: int) -> None: self.extend_output(create_call_method(nargs)) - def create_load_attr(self, name) -> Instruction: + def create_load_attr(self, name: str) -> Instruction: if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("LOAD_ATTR", argval=name) - def load_attr(self, name): + def load_attr(self, name: str) -> None: self.append_output(self.create_load_attr(name)) - def create_load_attrs(self, names): + def create_load_attrs(self, names: str) -> list[Instruction]: return [self.create_load_attr(name) for name in names.split(".")] - def create_store_attr(self, name) -> Instruction: + def create_store_attr(self, name: str) -> Instruction: if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) return create_instruction("STORE_ATTR", argval=name) - def store_attr(self, name): + def store_attr(self, name: str) -> None: self.append_output(self.create_store_attr(name)) - def load_function_name(self, fn_name, push_null, num_on_stack=0): + def load_function_name( + self, fn_name: str, push_null: bool, num_on_stack: int = 0 + ) -> list[Instruction]: """Load the global fn_name on the stack num_on_stack down""" output = [] if push_null and sys.version_info >= (3, 11): @@ -495,7 +507,7 @@ def load_function_name(self, fn_name, push_null, num_on_stack=0): ) return output - def rot_n(self, n): + def rot_n(self, n: int) -> list[Instruction]: try: return create_rot_n(n) except AttributeError: @@ -508,29 +520,29 @@ def rot_n(self, n): create_instruction("UNPACK_SEQUENCE", arg=n), ] - def pop_top(self): + def pop_top(self) -> None: self.append_output(create_instruction("POP_TOP")) - def call_function(self, nargs: int, push_null: bool): + def call_function(self, nargs: int, push_null: bool) -> None: self.extend_output(create_call_function(nargs, push_null=push_null)) - def dup_top(self): + def dup_top(self) -> None: self.append_output(create_dup_top()) - def store(self, varname): + def store(self, varname: str) -> None: self.append_output(self.create_store(varname)) - def load_deref(self, varname): + def load_deref(self, varname: str) -> None: self.append_output(self.create_load_deref(varname)) def make_function_with_closure( - self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 - ): + self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack: int = 0 + ) -> None: freevars = code.co_freevars assert freevars output = self._output - def gen_fn(): + def gen_fn() -> None: # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars` # requires that in the generated bytecode, these cells would keep # their original local names, which we ensure via @@ -561,7 +573,7 @@ def gen_fn(): output.extend(self.rot_n(num_on_stack + 1)) self.clear_tos() - def create_load_python_module(self, mod) -> Instruction: + def create_load_python_module(self, mod: types.ModuleType) -> Instruction: """ Generate a LOAD_GLOBAL instruction to fetch a given python module. """ @@ -589,7 +601,7 @@ def make_call_generated_code(self, fn_name: str) -> None: seen_sources: OrderedSet[Source] = OrderedSet() - def collect_temp_source(source): + def collect_temp_source(source: Source) -> None: if source in seen_sources: # This source is used at least twice, so it can be reused self.mark_source_temp(source) @@ -655,10 +667,10 @@ def collect_temp_source(source): self.extend_output(create_call_function(len(graphargs), False)) - def create_import_name(self, module_name) -> Instruction: + def create_import_name(self, module_name: str) -> Instruction: return create_instruction("IMPORT_NAME", argval=module_name) - def load_import_from(self, module_name, object_name) -> None: + def load_import_from(self, module_name: str, object_name: str) -> None: source = AttrSource(self.tx.import_source(module_name), object_name) # Note: This approach is somewhat aggressive because typically, a source is marked # as a tempvar only when it is used more than once. In this case, we're marking it @@ -667,7 +679,9 @@ def load_import_from(self, module_name, object_name) -> None: self.mark_source_temp(source) self(source) - def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instruction]: + def create_call_function_kw( + self, nargs: int, kw_names: Iterable[str], push_null: bool + ) -> list[Instruction]: if sys.version_info >= (3, 13): output = create_call_function(nargs, push_null) assert output[-1].opname == "CALL" @@ -691,5 +705,5 @@ def create_call_function_kw(self, nargs, kw_names, push_null) -> list[Instructio create_instruction("CALL_FUNCTION_KW", arg=nargs), ] - def create_delete(self, value) -> Instruction: + def create_delete(self, value: object) -> Instruction: return create_instruction("DELETE_FAST", argval=value) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index e52fb5026cb9..bda2494e7a9f 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -304,11 +304,13 @@ def begin_capture( accumulate_grad: bool, check_nans: bool, ): + global in_compiled_autograd_initial_trace counters["compiled_autograd"]["captures"] += 1 self.id = next(COMPILE_COUNTER) self.aot_id_counter: dict[int, int] = defaultdict(int) self.compile_context = make_compile_context(self.id) self.compile_context.__enter__() + in_compiled_autograd_initial_trace = True self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None self.start_time_ns = time.time_ns() get_chromium_event_logger().log_event_start( @@ -969,6 +971,8 @@ def create_graph_module(self, id): return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id) def end_capture(self, outputs): + global in_compiled_autograd_initial_trace + self.fx_tracer.create_proxy( "call_function", FakeCompiledAutogradEngine._exec_final_callbacks_stub, @@ -1085,6 +1089,7 @@ def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks, packed_inputs): log_pt2_compile_event=True, ) self.compile_context.__exit__(None, None, None) + in_compiled_autograd_initial_trace = False return runtime_wrapper, self.compiler_fn(graph) @staticmethod @@ -1394,6 +1399,9 @@ def set_node_origin( # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" compiled_autograd_enabled_force_eager = False +# global flag to check if we are capturing for compiled autograd +in_compiled_autograd_initial_trace = False + # global flag to check if we are processing graphs produced from a compiled autograd graph in_compiled_autograd_region = False @@ -1498,12 +1506,13 @@ def _disable(): # return to starting state of a new process def reset() -> None: - global compiled_autograd_enabled + global compiled_autograd_enabled, in_compiled_autograd_initial_trace compiled_autograd_enabled = False assert not in_compiled_autograd_region torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False) torch._C._dynamo.compiled_autograd.set_verbose_logger(None) torch._C._dynamo.compiled_autograd.clear_cache() + in_compiled_autograd_initial_trace = False global COMPILE_COUNTER COMPILE_COUNTER = itertools.count() diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index e21855563efd..2864168dfb82 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides the public comptime interface to TorchDynamo, enabling users to execute arbitrary Python code during symbolic evaluation of their programs. @@ -40,9 +38,13 @@ def my_model(x): import dis import time import traceback -from typing import Optional, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, TextIO, Union import torch +from torch._dynamo.symbolic_convert import InstructionTranslatorBase +from torch._dynamo.variables.base import VariableTracker +from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.symbolic_shapes import free_symbols from .exc import unimplemented_v2 @@ -62,10 +64,10 @@ class ComptimeVar: actual data in the Tensor is.) """ - def __init__(self, v) -> None: + def __init__(self, v: VariableTracker) -> None: self.__variable = v - def as_proxy(self): + def as_proxy(self) -> Union[VariableTracker, Sequence[VariableTracker]]: """ Returns an fx.Proxy (or tuple/list of fx.Proxy) representing this variable in the FX graph we are assembling to pass @@ -79,13 +81,13 @@ def as_proxy(self): """ return self.__variable.as_proxy() - def is_proxy(self): + def is_proxy(self) -> bool: """ Returns True if as_proxy() would succeed. """ return self.__variable.is_proxy() - def as_fake(self): + def as_fake(self) -> Union[FakeTensor, torch.SymInt]: """ Returns a "fake" value (either a FakeTensor or a SymInt) representing the variable in question. This only works @@ -102,16 +104,16 @@ def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]: Returns the size of the tensor (if dim is None) or the size at the dimension dim. The returned size may be a SymInt. """ - return self.as_fake().size(dim) + return self.as_fake().size(dim) # type: ignore[union-attr, return-value] - def python_type(self): + def python_type(self) -> type: """ Returns what type(v) would have returned for the variable at compile time. """ return self.__variable.python_type() - def as_python_constant(self): + def as_python_constant(self) -> Any: """ Returns the Python value this variable would have, but only if it is completely known at compile-time (e.g., it is constant). @@ -123,19 +125,19 @@ def as_python_constant(self): """ return self.__variable.as_python_constant() - def is_python_constant(self): + def is_python_constant(self) -> bool: """ Returns True if as_python_constant would succeed. """ return self.__variable.is_python_constant() - def is_dynamic(self): + def is_dynamic(self) -> bool: if isinstance(self.__variable, SymNodeVariable): fs = free_symbols(self.__variable.sym_num) return bool(fs) return False - def force_static(self): + def force_static(self) -> None: """ Forces that a value is static, inducing a guard on its specific value """ @@ -149,7 +151,7 @@ def force_static(self): f"cannot force {self.__variable} ({type(self.__variable)}) static" ) - def _i_will_not_complain_if_bc_breaks_VariableTracker(self): + def _i_will_not_complain_if_bc_breaks_VariableTracker(self) -> VariableTracker: """ Returns the internal data structure VariableTracker that Dynamo uses to represent variables at compile time. There are no BC guarantees on @@ -171,10 +173,10 @@ class ComptimeContext: file a feature request at https://github.com/pytorch/pytorch/ """ - def __init__(self, tx) -> None: + def __init__(self, tx: InstructionTranslatorBase) -> None: self.__tx = tx - def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: + def get_local(self, name: str, *, stacklevel: int = 0) -> ComptimeVar: """ Retrieve the compile-time known information about a local. """ @@ -187,7 +189,7 @@ def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar: return ComptimeVar(var) - def graph_break(self, msg="ComptimeContext.graph_break"): + def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None: """ Manually trigger a graph break """ @@ -198,14 +200,14 @@ def graph_break(self, msg="ComptimeContext.graph_break"): hints=[], ) - def graph(self): + def graph(self) -> torch.fx.Graph: """ Retrieve the partially constructed FX graph that would be passed to the user compiler after compilation. """ return self.__tx.output.graph - def assert_static(self, val): + def assert_static(self, val: ComptimeVar) -> None: """ Asserts that the int is static (and not dynamic, per dynamic shapes) """ @@ -213,7 +215,9 @@ def assert_static(self, val): "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)" ) - def print_graph(self, *, verbose=True, file=None): + def print_graph( + self, *, verbose: bool = True, file: Optional[TextIO] = None + ) -> None: """ Print the partially constructed FX graph that would be passed to the user compiler after compilation. @@ -222,19 +226,21 @@ def print_graph(self, *, verbose=True, file=None): self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file ) - def parent(self): - return ComptimeContext(self.__tx.parent) + def parent(self) -> "ComptimeContext": + return ComptimeContext(self.__tx.parent) # type: ignore[arg-type] - def __get_tx(self, stacklevel): + def __get_tx(self, stacklevel: int) -> Any: tx = self.__tx for _ in range(stacklevel): - tx = tx.parent + tx = tx.parent # type: ignore[assignment] return tx - def print(self, val, *, file=None): + def print(self, val: Any, *, file: Optional[TextIO] = None) -> None: print(repr(val), file=file) - def print_disas(self, *, file=None, stacklevel=0): + def print_disas( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print the current series of opcodes being executed (not including parent frames), including where you are in the particular opcode @@ -249,7 +255,9 @@ def print_disas(self, *, file=None, stacklevel=0): file=file, ) - def print_value_stack(self, *, file=None, stacklevel=0): + def print_value_stack( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print the current Python value stack. Note that this is NOT the same as the traceback; use print_bt() to print that. Note that at @@ -264,7 +272,9 @@ def print_value_stack(self, *, file=None, stacklevel=0): for s in tx.stack: print(f"- {s.debug_repr()}", file=file) - def print_locals(self, *, file=None, stacklevel=0): + def print_locals( + self, *, file: Optional[TextIO] = None, stacklevel: int = 0 + ) -> None: """ Print all of the locals available in the current context. By default this view is very limited; you can get more information @@ -274,7 +284,7 @@ def print_locals(self, *, file=None, stacklevel=0): for k, v in tx.symbolic_locals.items(): print(f"{k} = {v.debug_repr()}", file=file) - def print_bt(self, *, file=None, stacklevel=0): + def print_bt(self, *, file: Optional[TextIO] = None, stacklevel: int = 0) -> None: """ Print the user code backtrace, starting at the beginning of the frame Dynamo started evaluating. Note that this MAY NOT go all @@ -293,7 +303,7 @@ def print_bt(self, *, file=None, stacklevel=0): file=file, ) - def print_guards(self, *, file=None): + def print_guards(self, *, file: Optional[TextIO] = None) -> None: """ Print the currently installed guards for the Dynamo context. This does NOT include guards associated with variables that @@ -307,7 +317,9 @@ def print_guards(self, *, file=None): file=file, ) - def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): + def _i_will_not_complain_if_bc_breaks_InstructionTranslator( + self, + ) -> InstructionTranslatorBase: """ Returns the internal data structure InstructionTranslator that Dynamo uses to track state of symbolic evaluation. There are no BC @@ -316,32 +328,35 @@ def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self): """ return self.__tx - def sleep(self, sec): + def sleep(self, sec: Union[int, float]) -> None: time.sleep(sec) class _Comptime: @staticmethod - def __call__(fn, fallback_fn=lambda: None): + def __call__( + fn: Callable[[ComptimeContext], Any], + fallback_fn: Callable[[], Any] = lambda: None, + ) -> Any: """fn gets called at compile time in TorchDynamo, calls fallback_fn otherwise""" fallback_fn() # Convenience wrappers that are more compact to use @staticmethod - def graph_break(): + def graph_break() -> None: comptime(lambda ctx: ctx.graph_break()) @staticmethod - def print(e): + def print(e: Any) -> None: comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e)) @staticmethod - def print_graph(): + def print_graph() -> None: comptime(lambda ctx: ctx.print_graph()) @staticmethod - def print_disas(*, stacklevel=0): + def print_disas(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_disas( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -349,7 +364,7 @@ def print_disas(*, stacklevel=0): ) @staticmethod - def print_value_stack(*, stacklevel=0): + def print_value_stack(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -360,7 +375,7 @@ def print_value_stack(*, stacklevel=0): # in an expression context; e.g., x + print_value_stack_and_return(y + z), # you will see x on the stack prior to the addition operation @staticmethod - def print_value_stack_and_return(e, *, stacklevel=0): + def print_value_stack_and_return(e: Any, *, stacklevel: int = 0) -> Any: comptime( lambda ctx: ctx.print_value_stack( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -369,7 +384,7 @@ def print_value_stack_and_return(e, *, stacklevel=0): return e @staticmethod - def print_locals(*, stacklevel=0): + def print_locals(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_locals( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -377,7 +392,7 @@ def print_locals(*, stacklevel=0): ) @staticmethod - def print_bt(*, stacklevel=0): + def print_bt(*, stacklevel: int = 0) -> None: comptime( lambda ctx: ctx.print_bt( stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1 @@ -385,19 +400,19 @@ def print_bt(*, stacklevel=0): ) @staticmethod - def print_guards(): + def print_guards() -> None: comptime(lambda ctx: ctx.print_guards()) @staticmethod - def assert_static(val): + def assert_static(val: Any) -> None: comptime(lambda ctx: ctx.assert_static(ctx.get_local("val"))) @staticmethod - def force_static(val): + def force_static(val: Any) -> None: comptime(lambda ctx: ctx.get_local("val").force_static()) @staticmethod - def breakpoint(): + def breakpoint() -> None: """ Like pdb breakpoint(), but drop into pdb whenever this line of code is compiled by dynamo. Use it by putting @@ -415,14 +430,14 @@ def breakpoint(): (Pdb) p ctx.get_local("attention").as_fake() """ - def inner(inner_ctx): + def inner(inner_ctx: ComptimeContext) -> None: ctx = inner_ctx.parent() # noqa: F841 builtins.breakpoint() comptime(inner) @staticmethod - def sleep(sec): + def sleep(sec: Union[int, float]) -> None: comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant())) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 21598f71bced..7ef748b85f3e 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Configuration module for TorchDynamo compiler and optimization settings. @@ -284,6 +282,13 @@ # Defaults to False for BC. allow_unspec_int_on_nn_module = False +# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions, +# integer attributes on FSDP modules were treated as dynamic, while the same +# attributes on plain nn.Modules were static. We unified the behaviour by making +# FSDP ints static too. Set this flag to True to restore the legacy dynamic +# handling if needed. +allow_unspec_int_on_fsdp_module = False + # Specify how to optimize a compiled DDP module. The flag accepts a boolean # value or a string. There are 3 modes. # 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically @@ -450,7 +455,7 @@ record_compile_time_instruction_count = False -def default_debug_dir_root(): +def default_debug_dir_root() -> str: # [@compile_ignored: debug] DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR" if DEBUG_DIR_VAR_NAME in os.environ: @@ -608,6 +613,9 @@ def default_debug_dir_root(): os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1" ) +# Common prefix to append to the id of each compile run to filter out data +pt2_compile_id_prefix: Optional[str] = os.environ.get("PT2_COMPILE_ID_PREFIX", None) + # Run GC at the end of compilation run_gc_after_compile = Config( # type: ignore[var-annotated] default=True, @@ -629,7 +637,7 @@ def default_debug_dir_root(): if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 - def _make_closure_patcher(**changes): ... + def _make_closure_patcher(**changes: Any) -> Any: ... install_config_module(sys.modules[__name__]) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index fe547691add6..149a1c400d99 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-decorators - """ This module implements TorchDynamo's core frame conversion functionality, transforming Python frames into FX graphs. It handles: @@ -495,6 +493,29 @@ def _is_error_on_graph_break(tx: Optional[InstructionTranslator]) -> bool: return tx.error_on_graph_break +def get_compile_id( + frame_state: dict[str, Union[int, FrameStateSizeEntry]], +) -> CompileId: + global FRAME_COUNTER + if "_id" not in frame_state: + frame_state["_id"] = FRAME_COUNTER + FRAME_COUNTER += 1 + frame_id = frame_state["_id"] + assert isinstance(frame_id, int) + + frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] + FRAME_COMPILE_COUNTER[frame_id] += 1 + + compiled_autograd_id = None + if prior := CompileContext.current_compile_id(): + compiled_autograd_id = prior.compiled_autograd_id + return CompileId( + compiled_autograd_id=compiled_autograd_id, + frame_id=frame_id, + frame_compile_id=frame_compile_id, + ) + + class ConvertFrameAssert: def __init__( self, @@ -610,24 +631,8 @@ def __call__( global initial_global_state initial_global_state = GlobalStateGuard() - global FRAME_COUNTER - if "_id" not in frame_state: - frame_state["_id"] = FRAME_COUNTER - FRAME_COUNTER += 1 - frame_id = frame_state["_id"] - assert isinstance(frame_id, int) - - frame_compile_id = FRAME_COMPILE_COUNTER[frame_id] - FRAME_COMPILE_COUNTER[frame_id] += 1 - - compiled_autograd_id = None - if prior := CompileContext.current_compile_id(): - compiled_autograd_id = prior.compiled_autograd_id - compile_id = CompileId( - compiled_autograd_id=compiled_autograd_id, - frame_id=frame_id, - frame_compile_id=frame_compile_id, - ) + compile_id = get_compile_id(frame_state) + frame_id = compile_id.frame_id signpost_event( "dynamo", diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index a23b58cedf22..e084222c2171 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs -# mypy: disable-error-code="method-assign" - """ Debug utilities for TorchDynamo compilation and execution. @@ -34,6 +31,7 @@ import tempfile import textwrap from collections import Counter +from collections.abc import Sequence from importlib import import_module from typing import Any, Callable, Optional, TypeVar @@ -43,7 +41,9 @@ from torch import Tensor from torch._dynamo.testing import rand_strided from torch._prims_common import is_float_dtype +from torch.hub import tqdm from torch.multiprocessing.reductions import StorageWeakRef +from torch.storage import UntypedStorage from torch.utils._content_store import ContentStoreReader, ContentStoreWriter from . import config @@ -64,6 +64,7 @@ extra_deps = [] extra_imports = "" +cur_target = "" if use_buck: extra_deps = [ "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", @@ -79,7 +80,7 @@ class BuckTargetWriter: - def __init__(self, filename): + def __init__(self, filename: str) -> None: self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) self.target = self.py_file.replace(".py", "") @@ -93,7 +94,7 @@ def __init__(self, filename): tmp = tmp[tmp.find("fbcode/") :][7:] self.cmd_line_path = f"//{tmp}:{self.target}" - def build(self): + def build(self) -> str: extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) return textwrap.dedent( f""" @@ -119,7 +120,7 @@ def build(self): """ ) - def write(self, print_msg=True): + def write(self, print_msg: bool = True) -> list[str]: target_file = os.path.join(self.subdir, "TARGETS") with open(target_file, "w") as fd: fd.write(self.build()) @@ -133,7 +134,7 @@ def write(self, print_msg=True): return cmd_split -def minifier_dir(): +def minifier_dir() -> str: path = os.path.join(get_debug_dir(), "minifier") if path is None: path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}" @@ -171,7 +172,7 @@ class NNModuleToString: ] @staticmethod - def can_convert_to_string(gm): + def can_convert_to_string(gm: torch.fx.GraphModule) -> bool: cant_convert = set() for _, module in gm.named_children(): if type(module) not in NNModuleToString.safe_reprs: @@ -183,7 +184,7 @@ def can_convert_to_string(gm): return True @staticmethod - def convert(gm): + def convert(gm: torch.fx.GraphModule) -> str: from torch.nn.modules.module import _addindent tab = " " * 4 @@ -248,7 +249,7 @@ def __init__(self) -> None: @functools.cache # subprocess is expensive -def _cuda_system_info_comment(): +def _cuda_system_info_comment() -> str: if not torch.cuda.is_available(): return "# torch.cuda.is_available()==False, no GPU info collected\n" @@ -272,7 +273,7 @@ def _cuda_system_info_comment(): return model_str -def generate_env_vars_string(*, stable_output=False): +def generate_env_vars_string(*, stable_output: bool = False) -> str: """ Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton. """ @@ -282,7 +283,7 @@ def generate_env_vars_string(*, stable_output=False): allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"] skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"] - def filter(key): + def filter(key: str) -> bool: return any(string in key for string in allow_list) and key not in skip_list config_lines = [ @@ -297,7 +298,7 @@ def filter(key): """ -def generate_config_string(*, stable_output=False): +def generate_config_string(*, stable_output: bool = False) -> str: import torch._functorch.config import torch._inductor.config @@ -317,11 +318,11 @@ def generate_config_string(*, stable_output=False): """ -def get_minifier_repro_path(): +def get_minifier_repro_path() -> str: return os.path.join(minifier_dir(), "minifier_launcher.py") -def helper_for_dump_minify(contents): +def helper_for_dump_minify(contents: str) -> None: minified_repro_path = get_minifier_repro_path() log.warning("Writing minified repro to:\n%s", minified_repro_path) @@ -340,7 +341,7 @@ class AccuracyError(Exception): pass -def clone_inputs_retaining_gradness(example_inputs): +def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]: """ This clone inputs is different from utils clone_input. In case of minifier, all the tensors are leaf tensors while creating a new graph. So, we set the @@ -350,10 +351,15 @@ def clone_inputs_retaining_gradness(example_inputs): for idx in range(len(example_inputs)): if isinstance(cloned_inputs[idx], torch.Tensor): cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) - return cloned_inputs + return cloned_inputs # type: ignore[return-value] -def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): +def run_fwd_maybe_bwd( + gm: torch.fx.GraphModule, + args: Sequence[Any], + only_fwd: bool = False, + disable_clone: bool = False, +) -> Any: """ Runs a forward and possibly backward iteration for a given mod and args. @@ -381,14 +387,14 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): def same_two_models( - gm, - opt_gm, - example_inputs, - only_fwd=False, + gm: torch.fx.GraphModule, + opt_gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + only_fwd: bool = False, *, - require_fp64=False, - ignore_non_fp=False, -): + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: """ Check two models have same accuracy. @@ -438,7 +444,7 @@ def same_two_models( return passing -def cast_dtype_args_to_fp64(model): +def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule: for node in model.graph.nodes: if ( node.op == "call_function" @@ -459,7 +465,9 @@ def cast_dtype_args_to_fp64(model): return model -def cast_to(dtype, model, inputs): +def cast_to( + dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any] +) -> tuple[torch.fx.GraphModule, list[Any]]: from torch.utils._pytree import tree_map model = model.to(dtype) @@ -477,19 +485,21 @@ def cast_to(dtype, model, inputs): return model, inputs -def cast_to_fp64(model, inputs): +def cast_to_fp64( + model: torch.fx.GraphModule, inputs: list[Any] +) -> tuple[torch.fx.GraphModule, list[Any]]: return cast_to(torch.float64, model, inputs) def backend_accuracy_fails( - gm, - example_inputs, - compiler_fn, - only_fwd=False, + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], + only_fwd: bool = False, *, - require_fp64=False, - ignore_non_fp=False, -): + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: try: compiled_gm = compiler_fn( copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) @@ -545,20 +555,27 @@ class NopInputReader: def __init__(self) -> None: self.total = 0 - def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + def storage( + self, + storage_hash: Optional[str], + nbytes: int, + *, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> None: self.total += 1 - def tensor(self, *args, **kwargs): + def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]: pass - def symint(self, *args, **kwargs): + def symint(self, *args: Any, **kwargs: Any) -> Optional[int]: pass # TODO: Support bundling the entire repro into a zip file for ease of # transferring around class InputReader: - def __init__(self, save_dir=None, *, pbar=None): + def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = None): # If None, we will generate random data instead. It's important # to natively support this use case as it will allow people to # share repros without including the real data, if the problem @@ -566,13 +583,20 @@ def __init__(self, save_dir=None, *, pbar=None): if save_dir is None: log.warning("no save_dir specified, will generate random data") self.store = ContentStoreReader(save_dir) if save_dir is not None else None - self.args = [] + self.args: list[Any] = [] self.pbar = pbar - def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): + def storage( + self, + storage_hash: Optional[str], + nbytes: int, + *, + device: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> UntypedStorage: if self.pbar is not None: self.pbar.update(1) - device = _device_or_default(device) + device = _device_or_default(device) # type: ignore[arg-type] dtype_hint = _dtype_or_default(dtype_hint) if self.store is not None and storage_hash is not None: try: @@ -593,16 +617,16 @@ def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): def tensor( self, - storage, - shape, - stride=None, + storage: UntypedStorage, + shape: "torch._prims_common.ShapeType", + stride: Optional["torch._prims_common.StrideType"] = None, *, - storage_offset=None, - dtype=None, - requires_grad=None, - is_leaf=None, - **metadata, - ): + storage_offset: Optional[int] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: Optional[bool] = None, + is_leaf: Optional[bool] = None, + **metadata: Any, + ) -> torch.Tensor: stride = _stride_or_default(stride, shape=shape) storage_offset = _storage_offset_or_default(storage_offset) dtype = _dtype_or_default(dtype) @@ -624,7 +648,7 @@ def tensor( self.args.append(t) return t # for BC - def symint(self, val): + def symint(self, val: Any) -> Any: self.args.append(val) return val # for BC @@ -642,8 +666,8 @@ def symint(self, val): class InputWriter: - def __init__(self, save_dir, *, stable_hash=False): - self._lines = [] + def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None: + self._lines: list[str] = [] # TODO: consider ensuring tensor and storage counters line up? self.storage_counter = itertools.count() self.save_dir = save_dir @@ -652,9 +676,9 @@ def __init__(self, save_dir, *, stable_hash=False): if save_dir is not None else None ) - self.seen_storages = {} + self.seen_storages: dict[StorageWeakRef, str] = {} - def lines(self): + def lines(self) -> list[str]: r = [ "def load_args(reader):", ] @@ -669,7 +693,13 @@ def lines(self): # of initialization may be appropriate # # If we had a FakeTensor, device_hint tells us what device should be - def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: + def storage( + self, + untyped_storage: UntypedStorage, + *, + device_hint: Optional["torch._prims_common.DeviceLikeType"] = None, + dtype_hint: Optional[torch.dtype] = None, + ) -> str: ws = StorageWeakRef(untyped_storage) v = self.seen_storages.get(ws) if v is not None: @@ -684,7 +714,7 @@ def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: device = untyped_storage.device if device.type == "meta": assert device_hint is not None - device = device_hint + device = device_hint # type: ignore[assignment] if _device_or_default(None) != device: maybe_device = f", device={device!r}" nbytes = untyped_storage.nbytes() @@ -697,7 +727,7 @@ def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: self.seen_storages[ws] = v return v - def tensor(self, name, t) -> None: + def tensor(self, name: str, t: torch.Tensor) -> None: from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq storage = self.storage( @@ -729,7 +759,7 @@ def tensor(self, name, t) -> None: + f") # {name}" ) - def unsupported(self, name, arg): + def unsupported(self, name: str, arg: Any) -> None: # NB: Try hard not to /print/ a tensor, that will be very slow self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}") # Best effort dump as much useful stuff we can lol, in case you want @@ -747,13 +777,13 @@ def unsupported(self, name, arg): self._lines.append('"""') # write out that the arg was filtered out as it is constant - def const(self, name) -> None: + def const(self, name: str) -> None: self._lines.append( f"reader.const({name!r}) # {name}, filtered out during compilation" ) # TODO: this doesn't actually symint atm - def symint(self, name, val) -> None: + def symint(self, name: str, val: Any) -> None: if isinstance(val, torch.SymInt): val = val.node.hint self._lines.append(f"reader.symint({val!r}) # {name}") @@ -782,8 +812,10 @@ def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "S from torch.utils._dtype_abbrs import dtype_abbrs - dtype_map = {value: key for key, value in dtype_abbrs.items()} - dtype_pattern = "|".join(dtype_abbrs.values()) + dtype_map: dict[str, torch.dtype] = { + value: key for key, value in dtype_abbrs.items() + } + dtype_pattern: str = "|".join(dtype_abbrs.values()) # Extracting the source code from the function source = inspect.getsource(func) @@ -799,21 +831,23 @@ class TensorContainer: # Dictionary for tensors from annotations kwargs: dict[str, Any] = {} - sym_shapes = sym_shapes or {} + sym_shapes_dict: dict[str, int] = sym_shapes or {} - def get_sym_int(symint): + def get_sym_int(symint: str) -> int: torch._check( - symint in sym_shapes or default_sym_shape is not None, + symint in sym_shapes_dict or default_sym_shape is not None, lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in", ) - return sym_shapes.get(symint, default_sym_shape) + return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value] - def gen_tensor(shape, dtype) -> Tensor: + def gen_tensor( + shape: "torch._prims_common.ShapeType", dtype: torch.dtype + ) -> Tensor: # Resolve symbolic shapes to concrete values resolved_shape = [] dynamic_dims = [] for i, dim in enumerate(shape): - dim = dim.strip() + dim = dim.strip() # type: ignore[attr-defined] if "s" in dim: s = get_sym_int(dim) resolved_shape.append(s) @@ -868,9 +902,9 @@ def profile_to_file(filename: str) -> Callable[[T], T]: prof = cProfile.Profile() filename = os.path.abspath(os.path.expanduser(filename)) - def decorator(fn): + def decorator(fn: Any) -> Any: @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: prof.enable() try: return fn(*args, **kwargs) @@ -879,7 +913,7 @@ def wrapper(*args, **kwargs): return wrapper - def save_it(): + def save_it() -> None: prof.dump_stats(filename) sys.stderr.write( textwrap.dedent( diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index d49f3c435e56..13b61b7fa3e3 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs -# ruff: noqa: TCH004 - """ This module provides decorators and utilities for controlling TorchDynamo's behavior during compilation. """ @@ -9,10 +6,12 @@ import inspect import weakref from dataclasses import dataclass +from types import TracebackType from typing import Any, Callable, Optional, overload, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch +from torch.compiler import is_compiling from torch.utils._contextlib import _DecoratorContextManager from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -29,7 +28,6 @@ from .exc import IncorrectUsage from .external_utils import ( get_nonrecursive_disable_wrapper, - is_compiling, wrap_dunder_call_ctx_manager, ) from .utils import _get_error_on_graph_break, _set_error_on_graph_break, is_function @@ -56,9 +54,11 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) -def run(fn=None): +def run(fn: Optional[Callable[_P, _R]] = None) -> Any: """Don't do any dynamic compiles, just use prior optimizations""" if fn is not None: fn = innermost_fn(fn) @@ -67,7 +67,7 @@ def run(fn=None): return RunOnlyContext() -def disable(fn=None, recursive=True, *, reason=None, wrapping=True): +def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ignore[no-untyped-def] """ Decorator to disable TorchDynamo @@ -87,7 +87,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): return DisableContext(msg=reason, wrapping=wrapping) else: - def wrap(fn): + def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: fn = innermost_fn(fn) assert callable(fn) @@ -106,7 +106,7 @@ def wrap(fn): skip_code(_nonrecursive_disable_wrapper_code) -def skip(fn=None): +def skip(fn: Optional[Callable[_P, _R]] = None) -> Callable[..., Any]: """ Skip frames associated with the function code, but still process recursively invoked frames @@ -116,7 +116,7 @@ def skip(fn=None): fn = innermost_fn(fn) assert callable(fn) skip_code(fn.__code__) - fn._torchdynamo_disable = True + fn._torchdynamo_disable = True # type: ignore[attr-defined] return fn @@ -134,7 +134,7 @@ def __init__( stance: str = "default", *, skip_guard_eval_unsafe: bool = False, - force_backend=None, + force_backend: Union[str, Callable[..., Any], None] = None, ) -> None: if force_backend is not None and stance != "default": raise RuntimeError("non-default stance cannot have force_backend set") @@ -142,29 +142,34 @@ def __init__( self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend) self.prev = _set_stance(self.stance) - def __call__(self, fn): + def __call__(self, fn: F) -> F: _set_stance(self.prev) wrapper = super().__call__(fn) # forbid wrapper in graph wrapper._dynamo_forbidden = True # type: ignore[attr-defined] return wrapper - def __enter__(self): + def __enter__(self) -> None: _set_stance(self.stance) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: _set_stance(self.prev) - def clone(self): + def clone(self) -> "set_stance": return self.__class__(self.stance.stance, force_backend=self.stance.backend) -def assume_constant_result(fn): - fn._dynamo_marked_constant = True +def assume_constant_result(fn): # type: ignore[no-untyped-def] + fn._dynamo_marked_constant = True # type: ignore[attr-defined] return fn -def allow_in_graph(fn): +def allow_in_graph(fn): # type: ignore[no-untyped-def] """ Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function and instead directly write it to the graph when encountered. @@ -182,14 +187,14 @@ def allow_in_graph(fn): trace_rules._allowed_callable_ids.add(fn_id) # Avoid id reuse which creates subtle bugs. - def deregister(): + def deregister() -> None: trace_rules._allowed_callable_ids.remove(fn_id) weakref.finalize(fn, deregister) return fn -def nonstrict_trace(traceable_fn): +def nonstrict_trace(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: # Like `allow_in_graph`, but with the following enhancements/differences: # # 1. Supports user-defined class as inputs, as long as the class has been @@ -210,7 +215,7 @@ def nonstrict_trace(traceable_fn): assert callable(traceable_fn), "nonstrict_trace expects a callable" @functools.wraps(traceable_fn) - def wrapped(*args, **kwargs): + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return traceable_fn(*args, **kwargs) wrapped_id = id(wrapped) @@ -222,7 +227,7 @@ def wrapped(*args, **kwargs): trace_rules._nonstrict_trace_callable_ids.add(wrapped_id) # Avoid id reuse which creates subtle bugs. - def deregister(): + def deregister() -> None: trace_rules._allowed_callable_ids.remove(wrapped_id) trace_rules._nonstrict_trace_callable_ids.remove(wrapped_id) @@ -231,8 +236,8 @@ def deregister(): return wrapped -def _disallow_in_graph_helper(throw_if_not_allowed): - def inner(fn): +def _disallow_in_graph_helper(throw_if_not_allowed: bool) -> Callable[..., Any]: + def inner(fn: Any) -> Any: if isinstance(fn, (list, tuple)): return [disallow_in_graph(x) for x in fn] assert callable(fn), "disallow_in_graph expects a callable" @@ -254,7 +259,7 @@ def inner(fn): return inner -def disallow_in_graph(fn): +def disallow_in_graph(fn: Callable[..., Any]) -> Any: """ Customize which functions TorchDynamo will exclude in the generated graph and force a graph break on. @@ -280,17 +285,17 @@ def fn(a): @_disallow_in_graph_helper(throw_if_not_allowed=False) -def graph_break(msg=""): +def graph_break(msg: str = "") -> None: """Force a graph break""" # NOTE: primarily used for internal debugging purposes! @_disallow_in_graph_helper(throw_if_not_allowed=False) -def skip_frame(msg=""): +def skip_frame(msg: str = "") -> None: """Force a skipped frame""" -def forbid_in_graph(fn): +def forbid_in_graph(fn: Any) -> Any: """ Customize which functions TorchDynamo will assert are not present while tracing. @@ -392,7 +397,9 @@ def wrapper(traceable_fn: Callable[_P, _R]) -> Callable[_P, _R]: else: traceable_sig = inspect.signature(traceable_fn) - def sig_ident(sig): + def sig_ident( + sig: inspect.Signature, + ) -> tuple[tuple[str, ...], set[str], dict[str, Any]]: # Ignore annotations for parameters and return type return ( tuple( @@ -472,7 +479,9 @@ def sig_ident(sig): def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R: return original_fn(*args, **kwargs) - def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: + def dispatch_fn( + self: VariableBuilder, value: Callable[_P, _R] + ) -> PolyfilledFunctionVariable: return PolyfilledFunctionVariable( value, source=self.source, @@ -497,7 +506,9 @@ def dispatch_fn(self, value: Callable[_P, _R]) -> PolyfilledFunctionVariable: # Helper function to flatten a tensor subclass and apply a function to # all inner tensors that match the outer dim. Used to reduce duplication # across the various marking APIs. -def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs): +def _apply_func_to_inner_tensors_of_same_dim( + func: Callable[..., Any], t: object, *args: Any, **kwargs: Any +) -> None: assert is_traceable_wrapper_subclass(t) attrs, _ctx = t.__tensor_flatten__() @@ -522,7 +533,12 @@ class directly; instead, use :func:`mark_dynamic`. @forbid_in_graph -def mark_unbacked(t, index, strict=False, specialize_on=None): +def mark_unbacked( + t: Any, + index: Union[int, list[Any], tuple[Any]], + strict: bool = False, + specialize_on: Optional[list[Any]] = None, +) -> None: """ Mark a tensor as having an unbacked dim. This changes the semantics of operations, we will always report the size does not equal zero/one, we will turn asserts @@ -565,7 +581,14 @@ def mark_unbacked(t, index, strict=False, specialize_on=None): @forbid_in_graph -def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): +def mark_dynamic( + t: Any, + index: Union[int, list[Any], tuple[Any]], + *, + min: Optional[int] = None, + max: Optional[int] = None, + specialize_on: Optional[list[Any]] = None, +) -> None: """ Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim. @@ -620,7 +643,7 @@ def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): # TODO(voz): Should we bounds check? t._dynamo_dynamic_indices.add(index) - t._dynamo_dynamic_range.add(_DimRange(index, min, max)) + t._dynamo_dynamic_range.add(_DimRange(index, min, max)) # type: ignore[arg-type] # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies: # TypeError: 'Attribute' object does not support item assignment @@ -636,7 +659,7 @@ def mark_dynamic(t, index, *, min=None, max=None, specialize_on=None): @forbid_in_graph -def maybe_mark_dynamic(t, index): +def maybe_mark_dynamic(t: Any, index: Union[int, list[Any], tuple[Any]]) -> None: """ Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this dimension ends up getting specialized, don't error). @@ -658,7 +681,9 @@ def maybe_mark_dynamic(t, index): maybe_mark_dynamic(t, i) -def mark_static(t, index=None): +def mark_static( + t: Any, index: Optional[Union[int, list[Any], tuple[Any]]] = None +) -> None: """ Mark a tensor as having a static dim or mark a nn module class as static. @@ -723,7 +748,7 @@ def mark_static(t, index=None): @forbid_in_graph -def mark_static_address(t, guard=True): +def mark_static_address(t: Any, guard: bool = True) -> None: """ Marks an input tensor whose data_ptr will not change across multiple calls to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation @@ -742,7 +767,7 @@ def mark_static_address(t, guard=True): # One day, Dynamo will support tracing into einops directly (no allow_in_graph needed) # Note that PyTorch supports multiple versions of einops, so when that day comes, # we still need to be really careful about version matches. -def _allow_in_graph_einops(): +def _allow_in_graph_einops() -> None: import einops try: @@ -773,21 +798,26 @@ def _allow_in_graph_einops(): # Proxy class for torch._dynamo.config patching - so dynamo can identify context managers/decorators # created by patch_dynamo_config, compared to ones created by a raw torch._dynamo.config.patch. class DynamoConfigPatchProxy: - def __init__(self, config_patch): + def __init__(self, config_patch: Any) -> None: self.config_patch = config_patch @property - def changes(self): + def changes(self) -> dict[str, Any]: return self.config_patch.changes # Decorator implementation that simply sets up `self` as a context manager. # Placed in external_utils so that we can trace through it. __call__ = wrap_dunder_call_ctx_manager - def __enter__(self): + def __enter__(self) -> None: return self.config_patch.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: return self.config_patch.__exit__(exc_type, exc_val, exc_tb) @@ -819,7 +849,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): del config -def _patch_dynamo_config_check(changes: dict[str, Any]): +def _patch_dynamo_config_check(changes: dict[str, Any]) -> None: for k, v in changes.items(): if k not in _allowed_config_patches: raise ValueError( @@ -871,7 +901,7 @@ def dont_skip_tracing(fn: None = None) -> DynamoConfigPatchProxy: ... def dont_skip_tracing(fn: Callable[_P, _R]) -> Callable[_P, _R]: ... -def dont_skip_tracing(fn=None): +def dont_skip_tracing(fn: Optional[Any] = None) -> Any: """ Context manager/decorator to trace into functions intentionally marked by developers to be skipped when tracing. @@ -885,16 +915,21 @@ def dont_skip_tracing(fn=None): class SetFullgraphDecoratorContextManager: - def __init__(self, fullgraph): + def __init__(self, fullgraph: bool) -> None: self.fullgraph = fullgraph __call__ = wrap_dunder_call_ctx_manager - def __enter__(self): + def __enter__(self) -> None: self.prev_fullgraph = _get_error_on_graph_break() _set_error_on_graph_break(self.fullgraph) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: _set_error_on_graph_break(self.prev_fullgraph) diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index 2ec7c5f7259f..eb315fc73190 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Device abstraction layer for TorchDynamo and Inductor backends. @@ -21,7 +19,7 @@ import time from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import torch @@ -44,17 +42,17 @@ class DeviceInterface: """ class device: - def __new__(cls, device: torch.types.Device): + def __new__(cls, device: torch.types.Device) -> Any: raise NotImplementedError class Event: - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo." ) class Stream: - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo." ) @@ -68,7 +66,7 @@ class Worker: """ @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: raise NotImplementedError @staticmethod @@ -76,15 +74,15 @@ def current_device() -> int: raise NotImplementedError @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: raise NotImplementedError @staticmethod - def current_device(): + def current_device() -> int: raise NotImplementedError @staticmethod - def set_device(device: torch.types.Device): + def set_device(device: torch.types.Device) -> None: raise NotImplementedError @staticmethod @@ -96,7 +94,7 @@ def exchange_device(device: int) -> int: raise NotImplementedError @staticmethod - def device_count(): + def device_count() -> int: raise NotImplementedError @staticmethod @@ -104,19 +102,19 @@ def is_available() -> bool: raise NotImplementedError @staticmethod - def stream(stream: torch.Stream): + def stream(stream: torch.Stream) -> Any: raise NotImplementedError @staticmethod - def current_stream(): + def current_stream() -> torch.Stream: raise NotImplementedError @staticmethod - def set_stream(stream: torch.Stream): + def set_stream(stream: torch.Stream) -> None: raise NotImplementedError @staticmethod - def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): + def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None: raise NotImplementedError @staticmethod @@ -124,19 +122,19 @@ def get_raw_stream(device_idx: int) -> int: raise NotImplementedError @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: raise NotImplementedError @classmethod - def get_device_properties(cls, device: torch.types.Device = None): + def get_device_properties(cls, device: torch.types.Device = None) -> Any: return cls.Worker.get_device_properties(device) @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Any: raise NotImplementedError @staticmethod - def is_bf16_supported(including_emulation: bool = False): + def is_bf16_supported(including_emulation: bool = False) -> bool: raise NotImplementedError @classmethod @@ -188,11 +186,11 @@ def __init__( self.idx = index self.prev_idx = -1 - def __enter__(self): + def __enter__(self) -> None: if self.idx is not None: self.prev_idx = self.device_interface.exchange_device(self.idx) - def __exit__(self, type: Any, value: Any, traceback: Any): + def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: if self.idx is not None: self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) return False @@ -208,7 +206,7 @@ class CudaInterface(DeviceInterface): class Worker: @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: caching_worker_current_devices["cuda"] = device @staticmethod @@ -218,7 +216,7 @@ def current_device() -> int: return torch.cuda.current_device() @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: if device is not None: if isinstance(device, str): device = torch.device(device) @@ -247,8 +245,8 @@ def get_device_properties(device: torch.types.Device = None): synchronize = staticmethod(torch.cuda.synchronize) get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] - exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] - maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] + exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type, has-type] + maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type, has-type] memory_allocated = staticmethod(torch.cuda.memory_allocated) is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] @@ -258,7 +256,7 @@ def is_available() -> bool: return torch.cuda.is_available() @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]: if torch.version.hip is None: major, min = torch.cuda.get_device_capability(device) return major * 10 + min @@ -303,7 +301,7 @@ class XpuInterface(DeviceInterface): class Worker: @staticmethod - def set_device(device: int): + def set_device(device: int) -> None: caching_worker_current_devices["xpu"] = device @staticmethod @@ -313,7 +311,7 @@ def current_device() -> int: return torch.xpu.current_device() @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> Any: if device is not None: if isinstance(device, str): device = torch.device(device) @@ -352,7 +350,7 @@ def is_available() -> bool: return torch.xpu.is_available() @staticmethod - def get_compute_capability(device: torch.types.Device = None): + def get_compute_capability(device: torch.types.Device = None) -> Any: cc = torch.xpu.get_device_capability(device) return cc @@ -365,7 +363,7 @@ def is_triton_capable(device: torch.types.Device = None) -> bool: return True @staticmethod - def raise_if_triton_unavailable(evice: torch.types.Device = None) -> None: + def raise_if_triton_unavailable(device: torch.types.Device = None) -> None: import triton.backends if "intel" not in triton.backends.backends: @@ -379,18 +377,20 @@ class CpuDeviceProperties: class CpuInterface(DeviceInterface): class Event(torch.Event): - def __init__(self, enable_timing=True): + def __init__(self, enable_timing: bool = True) -> None: self.time = 0.0 - def elapsed_time(self, end_event) -> float: + def elapsed_time(self, end_event: Any) -> float: return (end_event.time - self.time) * 1000 - def record(self, stream=None): + def record(self, stream: Any = None) -> None: self.time = time.perf_counter() class Worker: @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties( + device: torch.types.Device = None, + ) -> CpuDeviceProperties: import multiprocessing cpu_count = multiprocessing.cpu_count() @@ -401,7 +401,7 @@ def is_available() -> bool: return True @staticmethod - def is_bf16_supported(including_emulation: bool = False): + def is_bf16_supported(including_emulation: bool = False) -> bool: return True @staticmethod @@ -409,15 +409,15 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod - def get_raw_stream(device_idx) -> int: + def get_raw_stream(device_idx: Any) -> int: return 0 @staticmethod - def current_device(): + def current_device() -> int: return 0 @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: pass @staticmethod @@ -450,7 +450,7 @@ def is_available() -> bool: return torch.backends.mps.is_available() @staticmethod - def current_device(): + def current_device() -> int: return 0 @staticmethod @@ -458,16 +458,16 @@ def get_compute_capability(device: torch.types.Device = None) -> str: return "" @staticmethod - def synchronize(device: torch.types.Device = None): + def synchronize(device: torch.types.Device = None) -> None: torch.mps.synchronize() class Worker: @staticmethod - def get_device_properties(device: torch.types.Device = None): + def get_device_properties(device: torch.types.Device = None) -> dict[str, Any]: return {} @staticmethod - def current_device(): + def current_device() -> int: return 0 @@ -477,7 +477,7 @@ def current_device(): def register_interface_for_device( device: Union[str, torch.device], device_interface: type[DeviceInterface] -): +) -> None: if isinstance(device, torch.device): device = device.type device_interfaces[device] = device_interface @@ -499,7 +499,7 @@ def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterfa return device_interfaces.items() -def init_device_reg(): +def init_device_reg() -> None: global _device_initialized register_interface_for_device("cuda", CudaInterface) for i in range(torch.cuda.device_count()): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 858aa402ca72..f47ca4185bed 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs # mypy: disable-error-code="method-assign" """ @@ -112,9 +111,20 @@ if TYPE_CHECKING: - from torch._subclasses import fake_tensor + from collections.abc import Iterable, Sequence - from .types import CacheEntry, DynamoCallback + from torch._dynamo.package import CompilePackage + from torch._dynamo.repro.after_dynamo import WrapBackendDebug + from torch._subclasses import fake_tensor + from torch.fx.node import Argument, Node, Target + + from .types import ( + CacheEntry, + DynamoCallback, + DynamoFrameType, + GuardFail, + GuardFilterEntry, + ) log = logging.getLogger(__name__) @@ -134,7 +144,7 @@ class Unset(Enum): unset = Unset.token -def _maybe_set_eval_frame(callback: DynamoCallback): +def _maybe_set_eval_frame(callback: DynamoCallback) -> DynamoCallback: # A wrapper on set_eval_frame that is guarded by a Justknob. # Users can disable torchDynamo by setting the JK to False. if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): @@ -176,7 +186,7 @@ def _set_stance(stance: DynamoStance) -> DynamoStance: _EXAMPLE_INPUTS: Optional[dict[str, list[Any]]] = None -def get_example_inputs(key) -> list[Any]: +def get_example_inputs(key: str) -> list[Any]: global _EXAMPLE_INPUTS if _EXAMPLE_INPUTS is None: _EXAMPLE_INPUTS = {} @@ -187,7 +197,7 @@ def get_example_inputs(key) -> list[Any]: return _EXAMPLE_INPUTS[key] -def _callback_from_stance(callback): +def _callback_from_stance(callback: DynamoCallback) -> DynamoCallback: if _stance.stance == "default": # force_backend if _stance.backend is not None and callback not in (False, None): @@ -212,7 +222,9 @@ def _callback_from_stance(callback): if callback in (False, None): return callback - def fail_callback(frame, *args, **kwargs): + def fail_callback( + frame: DynamoFrameType, *args: Any, **kwargs: Any + ) -> ConvertFrameReturn: if trace_rules.check(frame.f_code): return ConvertFrameReturn() @@ -239,7 +251,9 @@ def fail_callback(frame, *args, **kwargs): raise RuntimeError(f"invalid torch.compile stance '{_stance}'") -def _create_wrapped_callback(compiler_fn): +def _create_wrapped_callback( + compiler_fn: CompilerFn, +) -> convert_frame.CatchErrorsWrapper: hooks = Hooks() return convert_frame.catch_errors_wrapper( convert_frame.convert_frame( # type: ignore[arg-type] @@ -250,7 +264,7 @@ def _create_wrapped_callback(compiler_fn): ) -def _get_or_add_example_inputs(frame): +def _get_or_add_example_inputs(frame: DynamoFrameType) -> list[Any]: key = frame.f_code.co_filename + str(frame.f_code.co_firstlineno) example_inputs = get_example_inputs(key) @@ -260,8 +274,10 @@ def _get_or_add_example_inputs(frame): return example_inputs -def _create_delayed_compile_callback(callback, stance): - def callback_fn(*args, **kwargs): +def _create_delayed_compile_callback( + callback: DynamoCallback, stance: str +) -> Callable[..., Any]: + def callback_fn(*args: Any, **kwargs: Any) -> convert_frame.ConvertFrameReturn: frame = args[0] example_inputs = _get_or_add_example_inputs(frame) @@ -278,7 +294,7 @@ def callback_fn(*args, **kwargs): dynamism = track_dynamism_across_examples(example_inputs) code_context.get_context(frame.f_code)["dynamism"] = dynamism - compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend + compiler_fn = callback._torchdynamo_orig_backend._torchdynamo_orig_backend # type: ignore[union-attr] return _create_wrapped_callback(compiler_fn)(*args, **kwargs) # to prevent cache miss due to different backend @@ -287,11 +303,11 @@ def callback_fn(*args, **kwargs): return callback_fn -def _is_skip_guard_eval_unsafe_stance(): +def _is_skip_guard_eval_unsafe_stance() -> bool: return _stance.skip_guard_eval_unsafe -def _reset_guarded_backend_cache(): +def _reset_guarded_backend_cache() -> None: global cached_backends for backend in cached_backends.values(): if hasattr(backend, "reset"): @@ -339,7 +355,7 @@ class OptimizedModule(torch.nn.Module): "_super_module_initialized", } - def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: + def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> None: # NOTE: this must go first, because attribute reads/writes of `self` # uses `_orig_mod`, and sometimes users override `Module.__init__` to # do attribute reads/writes on `self`. @@ -357,7 +373,7 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None: self._initialize() self.training = self._orig_mod.training - def _initialize(self): + def _initialize(self) -> None: # Do this stuff in constructor to lower overhead slightly if isinstance(self.dynamo_ctx, DisableContext): # No need to check trace rules @@ -381,7 +397,7 @@ def _initialize(self): self._forward = self.forward self.forward = self._call_lazy_check - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: if torch.nn.modules.module._has_any_global_hook(): warnings.warn( "Using `torch.compile(module)` when there are global hooks on " @@ -394,37 +410,39 @@ def __call__(self, *args, **kwargs): ) return super().__call__(*args, **kwargs) - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[type[OptimizedModule], tuple[torch.nn.Module, _TorchDynamoContext]]: return (self.__class__, (self._orig_mod, self.dynamo_ctx)) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: state = dict(self.__dict__) state.pop("forward", None) state.pop("__call__", None) return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__ = state self._initialize() @property - def training(self): + def training(self) -> bool: return self._orig_mod.training @training.setter - def training(self, value): + def training(self, value: bool) -> None: # Ignore the `training` mutation in `super().__init__()`, since that's # setting the default on `nn.Module`, but we are mirroring the # `training` attr in `self._orig_mod`. if self._super_module_initialized: self._orig_mod.training = value - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name == "_orig_mod": return self._modules["_orig_mod"] return getattr(self._orig_mod, name) - def __setattr__(self, name, val) -> None: + def __setattr__(self, name: str, val: Any) -> None: # Allow patching over class attributes if hasattr(type(self), name): return super().__setattr__(name, val) @@ -433,7 +451,7 @@ def __setattr__(self, name, val) -> None: return super().__setattr__(name, val) return setattr(self._orig_mod, name, val) - def __delattr__(self, name): + def __delattr__(self, name: str) -> None: # This mirrors `__setattr__` if hasattr(type(self), name): return super().__delattr__(name) @@ -442,7 +460,7 @@ def __delattr__(self, name): return super().__delattr__(name) return delattr(self._orig_mod, name) - def _call_lazy_check(self, *args, **kwargs): + def _call_lazy_check(self, *args: Any, **kwargs: Any) -> Any: if ( hasattr(self._orig_mod, "_initialize_hook") and hasattr(self._orig_mod, "_infer_parameters") @@ -455,14 +473,14 @@ def _call_lazy_check(self, *args, **kwargs): self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) return self._forward(*args, **kwargs) - def __dir__(self): + def __dir__(self) -> list[str]: orig_mod_attrs = self._orig_mod.__dir__() return orig_mod_attrs + [ attr for attr in super().__dir__() if attr not in orig_mod_attrs ] -def remove_from_cache(f): +def remove_from_cache(f: Any) -> None: """ Make sure f.__code__ is not cached to force a recompile """ @@ -479,15 +497,17 @@ def remove_from_cache(f): log.warning("could not determine __code__ for %s", f) -def nothing(): +def nothing() -> None: pass -def always_false(): +def always_false() -> bool: return False -def innermost_fn(fn, unaltered_fn_attr="_torchdynamo_orig_callable"): +def innermost_fn( + fn: Callable[..., Any], unaltered_fn_attr: str = "_torchdynamo_orig_callable" +) -> Callable[..., Any]: """ In case of nesting of _TorchDynamoContext calls, find the innermost function. TorchDynamo caches on fn.__code__ object, so its necessary to find @@ -502,7 +522,7 @@ def innermost_fn(fn, unaltered_fn_attr="_torchdynamo_orig_callable"): return unaltered_fn -def make_set_enable_dynamic(enable: bool): +def make_set_enable_dynamic(enable: bool) -> Any: assert isinstance(enable, bool) if enable: # Assume everything is dynamic by default @@ -524,12 +544,12 @@ class DynamoTLS(threading.local): dynamo_tls = DynamoTLS() -def clear_dynamo_tls(): +def clear_dynamo_tls() -> None: dynamo_tls.traced_frame_infos.clear() @atexit.register -def _log_traced_frames(): +def _log_traced_frames() -> None: """ At program exit, log all of the frames Dynamo has attempted to trace from, excluding the continuation frames generated by Dynamo. @@ -540,7 +560,7 @@ def _log_traced_frames(): log.info(msg) -def guard_collectives_hook(guard_eval_result): +def guard_collectives_hook(guard_eval_result: bool) -> bool: import torch.distributed as dist from torch._dynamo.utils import dynamo_timed @@ -568,16 +588,18 @@ class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, - on_enter=nothing, - backend_ctx_ctor=null_context, - patch_fn=nothing, - first_ctx=False, + on_enter: Callable[[], Any] = nothing, + backend_ctx_ctor: Callable[ + [], contextlib.AbstractContextManager[Any] + ] = null_context, + patch_fn: Callable[[], Any] = nothing, + first_ctx: bool = False, *, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, - package=None, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, + package: Optional[CompilePackage] = None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -595,15 +617,15 @@ def __init__( patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset - backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") - cached_backends.setdefault(id(backend), backend) + backend = innermost_fn(callback, unaltered_fn_attr="_torchdynamo_orig_backend") # type: ignore[arg-type] + cached_backends.setdefault(id(backend), backend) # type: ignore[arg-type] if dynamic is not None: self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) if on_enter is not nothing: # this case is not common - def call_on_enter(): + def call_on_enter() -> Callable[[], None]: on_enter() return nothing @@ -611,14 +633,14 @@ def call_on_enter(): if backend_ctx_ctor is not contextlib.nullcontext: # this case is not common - def call_backend_ctx(): + def call_backend_ctx() -> functools.partial[Optional[bool]]: ctx = backend_ctx_ctor() ctx.__enter__() return functools.partial(ctx.__exit__, None, None, None) self.enter_exit_hooks.append(call_backend_ctx) - def __enter__(self): + def __enter__(self) -> None: if config.raise_on_ctx_manager_usage: raise RuntimeError( "torch._dynamo.optimize(...) is used with a context manager. " @@ -632,7 +654,12 @@ def __enter__(self): ) _maybe_set_eval_frame(_callback_from_stance(self.callback)) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType], + ) -> Optional[bool]: assert self.prior is not unset set_eval_frame(None) set_skip_guard_eval_unsafe(self.prior_skip_guard_eval_unsafe) @@ -641,10 +668,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.cleanup_fns.clear() _maybe_set_eval_frame(_callback_from_stance(self.prior)) self.prior = unset + return None - def __call__(self, fn): + def __call__(self, fn: Any) -> Any: # public api for compiler config/options - def get_compiler_config(): + def get_compiler_config() -> Any: return self.compiler_config from .package import DynamoCache @@ -721,19 +749,18 @@ def get_compiler_config(): # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) - def do_nothing(*arg, **kwargs): + def do_nothing(*arg: Any, **kwargs: Any) -> None: pass + callback: Callable[..., Any] = do_nothing if hasattr(self, "callback"): - callback = self.callback - else: - callback = do_nothing + callback = self.callback # type: ignore[assignment] is_jit_tracing = torch._C._is_tracing is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing @functools.wraps(fn) - def compile_wrapper(*args, **kwargs): + def compile_wrapper(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: if is_fx_tracing(): @@ -861,20 +888,20 @@ def compile_wrapper(*args, **kwargs): class OptimizeContext(_TorchDynamoContext): def __init__( self, - callback, - backend_ctx_ctor, - first_ctx=False, + callback: DynamoCallback, + backend_ctx_ctor: Callable[[], contextlib.AbstractContextManager[Any]], + first_ctx: bool = False, *, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, - package=None, + package: Optional[CompilePackage] = None, ) -> None: - def on_enter(): + def on_enter() -> None: install_generation_tagging_init() super().__init__( @@ -895,7 +922,7 @@ def on_enter(): if _dynamic is None: _dynamic = not torch._dynamo.config.assume_static_by_default - def call_compiled_autograd(): + def call_compiled_autograd() -> functools.partial[Optional[bool]]: assert rebuild_ctx is not None compiler_fn = rebuild_ctx() ctx = torch._dynamo.compiled_autograd._enable( @@ -906,7 +933,9 @@ def call_compiled_autograd(): self.enter_exit_hooks.append(call_compiled_autograd) - def __reduce__(self): + def __reduce__( + self, + ) -> tuple[type[OptimizeContext], tuple[Any, ...], dict[str, Any]]: return ( self.__class__, (self.callback, self._backend_ctx_ctor, self.first_ctx), @@ -921,12 +950,12 @@ def __reduce__(self): class RunOnlyContext(_TorchDynamoContext): def __init__(self) -> None: # cudagraph trees relies on generation increment - def on_enter(): + def on_enter() -> None: torch._dynamo.mutation_guard.GenerationTracker.generation += 1 super().__init__(callback=False, on_enter=on_enter) - def __reduce__(self): + def __reduce__(self) -> tuple[type[RunOnlyContext], tuple[Any, ...]]: return (self.__class__, ()) @@ -936,7 +965,7 @@ def __init__(self, msg: Optional[str] = None, wrapping: bool = True) -> None: self.msg = msg self.wrapping = wrapping - def __call__(self, fn): + def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: # Earlier this code was in the base class _TorchDynamoContext. But we # moved it here to have better code organization. For disable, we just # want the callback to be None. We don't have to check trace_rules or @@ -967,7 +996,7 @@ def __call__(self, fn): f"A callable function is expected, but {type(fn)} is provided." ) - def _fn(*args, **kwargs): + def _fn(*args: Any, **kwargs: Any) -> Any: prior = set_eval_frame(None) try: _maybe_set_eval_frame(_callback_from_stance(self.callback)) @@ -995,21 +1024,23 @@ def _fn(*args, **kwargs): return _fn - def __reduce__(self): + def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]: return (self.__class__, ()) def _optimize_catch_errors( - compile_fn, + compile_fn: convert_frame.ConvertFrameProtocol, hooks: Hooks, - backend_ctx_ctor=null_context, - error_on_graph_break=False, - export=False, - dynamic=None, - compiler_config=None, - rebuild_ctx=None, - package=None, -): + backend_ctx_ctor: Callable[ + [], contextlib.AbstractContextManager[Any] + ] = null_context, + error_on_graph_break: bool = False, + export: bool = False, + dynamic: Optional[bool] = None, + compiler_config: Optional[Any] = None, + rebuild_ctx: Optional[Callable[[], Union[OptimizeContext, _NullDecorator]]] = None, + package: Optional[CompilePackage] = None, +) -> OptimizeContext: return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), backend_ctx_ctor=backend_ctx_ctor, @@ -1023,11 +1054,17 @@ def _optimize_catch_errors( ) -def get_compiler_fn(compiler_fn): +def get_compiler_fn( + compiler_fn: Union[str, Callable[..., Any], None], +) -> WrapBackendDebug: from .repro.after_dynamo import wrap_backend_debug - if hasattr(compiler_fn, "compiler_name"): - compiler_str = compiler_fn.compiler_name + if compiler_fn is None: + # Special case None to avoid crashing in hasattr + compiler_str = None + elif hasattr(compiler_fn, "compiler_name"): + compiler_str = compiler_fn.compiler_name # type: ignore[union-attr] + assert isinstance(compiler_str, str) elif isinstance(compiler_fn, str): compiler_str = compiler_fn else: @@ -1037,14 +1074,14 @@ def get_compiler_fn(compiler_fn): class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg] - def __call__(self, fn): + def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]: assert callable(fn), ( f"A callable function is expected, but {type(fn)} is provided." ) return fn -def check_if_dynamo_supported(): +def check_if_dynamo_supported() -> None: if sys.version_info >= (3, 14): raise RuntimeError("Python 3.14+ not yet supported for torch.compile") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( @@ -1058,7 +1095,7 @@ def check_if_dynamo_supported(): ) -def is_dynamo_supported(): +def is_dynamo_supported() -> bool: try: check_if_dynamo_supported() return True @@ -1066,11 +1103,11 @@ def is_dynamo_supported(): return False -def check_if_inductor_supported(): +def check_if_inductor_supported() -> None: check_if_dynamo_supported() -def is_inductor_supported(): +def is_inductor_supported() -> bool: try: check_if_inductor_supported() return True @@ -1078,15 +1115,15 @@ def is_inductor_supported(): return False -def check_for_incompatible_configs(): +def check_for_incompatible_configs() -> None: # Some of the configs should be mutually exclusive assert not (config.suppress_errors and config.fail_on_recompile_limit_hit), ( "Dynamo configs suppress_error and fail_on_recompile_limit_hit can not both be active at the same time." ) -def optimize(*args, **kwargs): - def rebuild_ctx(): +def optimize(*args: Any, **kwargs: Any) -> Union[OptimizeContext, _NullDecorator]: + def rebuild_ctx() -> Union[OptimizeContext, _NullDecorator]: ca_kwargs_override = config.compiled_autograd_kwargs_override if ca_kwargs_override: # NOTE: The process of translating other `torch.compile` kwargs to `torch._dynamo.optimize` kwargs @@ -1102,15 +1139,15 @@ def rebuild_ctx(): def _optimize( rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], - backend="inductor", + backend: Union[str, Callable[..., Any]] = "inductor", *, - nopython=False, - guard_export_fn=None, - guard_fail_fn=None, - guard_filter_fn=None, - disable=False, - dynamic=None, - package=None, + nopython: bool = False, + guard_export_fn: Optional[Callable[[_guards.GuardsSet], None]] = None, + guard_fail_fn: Optional[Callable[[GuardFail], None]] = None, + guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None, + disable: bool = False, + dynamic: Optional[bool] = None, + package: Optional[CompilePackage] = None, ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1195,8 +1232,10 @@ def toy_example(a, b): ... # TODO(voz): Consider making "explain" output alongside a run / part of a run @patch("torch._dynamo.symbolic_convert.explain", True) -def explain(f, *extra_args, **extra_kwargs): - def inner(*args, **kwargs): +def explain(f: Callable[..., Any], *extra_args: Any, **extra_kwargs: Any) -> Any: + from .backends.debugging import ExplainOutput + + def inner(*args: Any, **kwargs: Any) -> ExplainOutput: # TODO(voz): Do we want a decorator for this? from . import reset # type: ignore[attr-defined] @@ -1209,8 +1248,8 @@ def inner(*args, **kwargs): out_guards: list[_guards.Guard] = [] def dynamo_graph_accumulating_compiler( - gm: torch.fx.GraphModule, example_inputs - ): + gm: torch.fx.GraphModule, example_inputs: Any + ) -> Callable[..., Any]: from .backends.debugging import _explain_graph_detail nonlocal graphs @@ -1224,7 +1263,7 @@ def dynamo_graph_accumulating_compiler( return gm.forward - def guard_export_print(guards): + def guard_export_print(guards: Iterable[_guards.Guard]) -> None: nonlocal out_guards out_guards.extend(guards) @@ -1242,7 +1281,6 @@ def guard_export_print(guards): # TODO(voz): Do we want a decorator for this? reset() - from .backends.debugging import ExplainOutput return ExplainOutput( graphs, @@ -1272,9 +1310,9 @@ class FlattenInputOutputSignature(torch.fx.Transformer): def __init__( self, m: torch.fx.GraphModule, - flat_args: tuple[Any], + flat_args: list[Any], matched_input_elements_positions: list[int], - flat_results: list[Any], + flat_results: Sequence[Any], matched_output_elements_positions: list[int], example_fake_inputs: list[torch.Tensor], flat_args_dynamic_dims: list[set[int]], @@ -1322,7 +1360,9 @@ def __init__( self.matched_output_elements_positions = matched_output_elements_positions self.flat_results = flat_results - def placeholder(self, target, args, kwargs): + def placeholder( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: arg = next(self.old_args_gen) if "val" in self.current_node.meta: arg.node.meta["val"] = self.current_node.meta["val"] @@ -1337,9 +1377,11 @@ def placeholder(self, target, args, kwargs): ] return arg - def output(self, target, args, kwargs): + def output( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: dynamo_result_flat = args[0] - lookup = [*dynamo_result_flat, *self.new_args] + lookup = [*dynamo_result_flat, *self.new_args] # type: ignore[misc] new_results_flat = [] for i in range(len(self.flat_results)): if self.matched_output_elements_positions[i] is not None: @@ -1352,7 +1394,7 @@ def output(self, target, args, kwargs): new_results_flat.append(const_val) return super().output(target, (new_results_flat,), {}) - def run_node(self, n): + def run_node(self, n: Node) -> Any: self.current_node = n result_proxy = super().run_node(n) if "val" in self.current_node.meta: @@ -1372,7 +1414,7 @@ def run_node(self, n): ) return result_proxy - def transform(self): + def transform(self) -> torch.fx.GraphModule: result_gm = super().transform() if "dynamo_flat_name_to_original_fqn" in self.module.meta: # type: ignore[operator] result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ # type: ignore[index] @@ -1391,15 +1433,17 @@ class ExportResult(NamedTuple): # NOTE: this function only supports graphs created by Dynamo's OutputGraph module -def check_signature_rewritable(graph): +def check_signature_rewritable(graph: torch.fx.GraphModule) -> None: input_errors = [] for node in graph.graph.find_nodes(op="placeholder"): # set in OutputGraph._call_user_compiler assert hasattr(node, "_dynamo_source") assert hasattr(graph, "_source_to_user_stacks") - source = node._dynamo_source - user_stacks = graph._source_to_user_stacks.get(source) + # NOTE: We can safely ignore these type warnings if and only if + # the function is made from OutputGraph (checked in the assertions) + source = node._dynamo_source # type: ignore[attr-defined] + user_stacks = graph._source_to_user_stacks.get(source) # type: ignore[operator, union-attr] if user_stacks is None: continue assert len(user_stacks) > 0 @@ -1436,20 +1480,22 @@ def check_signature_rewritable(graph): def rewrite_signature( - f_sig, - graph, - fake_mode, - flat_args, - in_spec, - example_fake_inputs, - graph_captured_input, - graph_captured_output, - dynamo_traced_result, - flat_args_dynamic_dims, -): + f_sig: inspect.Signature, + graph: torch.fx.GraphModule, + fake_mode: Optional[fake_tensor.FakeTensorMode], + flat_args: list[Any], + in_spec: pytree.TreeSpec, + example_fake_inputs: list[Any], + graph_captured_input: Iterable[Any], + graph_captured_output: Optional[Iterable[Any]], + dynamo_traced_result: Any, + flat_args_dynamic_dims: list[set[int]], +) -> torch.fx.GraphModule: orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) - def check_user_input_output(flat_values, error_type): + def check_user_input_output( + flat_values: list[Any], error_type: UserErrorType + ) -> None: supported_types = [ torch.Tensor, torch.SymInt, @@ -1459,7 +1505,7 @@ def check_user_input_output(flat_values, error_type): _IntWrapper, ] + list(common_constant_types) - def is_supported_type(val): + def is_supported_type(val: Any) -> bool: return isinstance(val, tuple(supported_types)) value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" @@ -1485,7 +1531,7 @@ def is_supported_type(val): flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) - def check_optional_input_and_error(f_sig: inspect.Signature): + def check_optional_input_and_error(f_sig: inspect.Signature) -> None: # Check if function has optional input. for name, param in f_sig.parameters.items(): if param.default is not inspect.Parameter.empty: @@ -1501,7 +1547,9 @@ def check_optional_input_and_error(f_sig: inspect.Signature): case_name="optional_input", ) - def produce_matching(debug_type, sources, candidates): + def produce_matching( + debug_type: str, sources: Iterable[Any], candidates: Iterable[Any] + ) -> list[Optional[int]]: matched_elements_positions: list[Optional[int]] = [] dict_of_source_vals = {} for i, val in enumerate(sources): @@ -1534,17 +1582,19 @@ def produce_matching(debug_type, sources, candidates): new_graph = FlattenInputOutputSignature( graph, flat_args, - matched_input_elements_positions, + matched_input_elements_positions, # type: ignore[arg-type] flat_results_traced, - matched_output_elements_positions, + matched_output_elements_positions, # type: ignore[arg-type] example_fake_inputs, flat_args_dynamic_dims, fake_mode, ).transform() # Make dynamo graph to have same input/output spec as user code - def argument_names(f_sig, args, kwargs) -> list[str]: - def signature_to_fullargspec(sig: inspect.Signature): + def argument_names( + f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any] + ) -> list[str]: + def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec: # Get a list of Parameter objects from the Signature object params = list(sig.parameters.values()) # Separate positional arguments, keyword-only arguments and varargs/varkw @@ -1638,7 +1688,7 @@ def signature_to_fullargspec(sig: inspect.Signature): def export( f: Callable[..., Any], - *extra_args, + *extra_args: Any, aten_graph: bool = False, pre_dispatch: bool = False, decomposition_table: Optional[ @@ -1654,7 +1704,7 @@ def export( allow_complex_guards_as_runtime_asserts: bool = False, _log_export_usage: bool = True, constraints: Optional[list[Constraint]] = None, - **extra_kwargs, + **extra_kwargs: Any, ) -> Callable[..., ExportResult]: """ Export an input function f to a format that can be executed outside of PyTorch using the FX graph. @@ -1718,7 +1768,7 @@ def export( _assume_static_by_default = assume_static_by_default _constraints = constraints - def inner(*args, **kwargs): + def inner(*args: Any, **kwargs: Any) -> ExportResult: if not _constraints: combined_args = _combine_args(_f, args, kwargs) constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) @@ -1738,7 +1788,7 @@ def inner(*args, **kwargs): assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" f = innermost_fn(f) call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f - original_signature = inspect.signature(call_to_inspect) + original_signature = inspect.signature(call_to_inspect) # type: ignore[arg-type] graph = None out_guards = None graph_captured_input = None @@ -1746,18 +1796,18 @@ def inner(*args, **kwargs): fake_mode = None result_traced = None - def guard_export_print(guards: _guards.GuardsSet): + def guard_export_print(guards: _guards.GuardsSet) -> None: nonlocal out_guards assert out_guards is None, ( "whole graph export entails exactly one guard export" ) out_guards = guards - example_inputs = [] + example_inputs: list[Any] = [] def dynamo_normalization_capturing_compiler( - gm: torch.fx.GraphModule, inner_example_inputs - ): + gm: torch.fx.GraphModule, inner_example_inputs: list[Any] + ) -> Callable[..., Any]: nonlocal graph assert graph is None, ( "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." @@ -1773,7 +1823,7 @@ def dynamo_normalization_capturing_compiler( fake_mode = _guards.detect_fake_mode() example_inputs = inner_example_inputs - def result_capturing_wrapper(*graph_inputs): + def result_capturing_wrapper(*graph_inputs: Any) -> Any: nonlocal graph_captured_result nonlocal graph_captured_input @@ -1815,7 +1865,14 @@ def result_capturing_wrapper(*graph_inputs): value, static_shapes=True ) - def fakify_with_ambient(path, t): + from torch._export.non_strict_utils import ( + key_path_to_source, + KeyPath, + ) + + def fakify_with_ambient( + path: KeyPath, t: Union[torch.Tensor, _IntWrapper, Any] + ) -> Any: if isinstance(t, torch.Tensor): return ambient_fake_mode.from_tensor(t, static_shapes=True) elif isinstance(t, _IntWrapper): @@ -1828,10 +1885,6 @@ def fakify_with_ambient(path, t): _DimHintType.AUTO, ) ): # type: ignore[union-attr] - from torch._export.non_strict_utils import ( - key_path_to_source, - ) - source = key_path_to_source(path) symint = ambient_fake_mode.shape_env.create_unspecified_symint_and_symbol( # type: ignore[union-attr] t.val, source, DimDynamic.DYNAMIC @@ -1989,7 +2042,7 @@ def fakify_with_ambient(path, t): if aten_graph: # Running graph with interpreter is needed for propagating the stack_trace - def graph_with_interpreter(*args): + def graph_with_interpreter(*args: Any) -> Any: with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(graph).run(*args) # type: ignore[arg-type] @@ -2039,12 +2092,12 @@ def graph_with_interpreter(*args): flat_args, in_spec, example_fake_inputs, - graph_captured_input, + graph_captured_input, # type: ignore[arg-type] graph_captured_result, result_traced, # type: ignore[possibly-undefined] flat_args_dynamic_dims, ) - return ExportResult(graph, out_guards) # type: ignore[arg-type] + return ExportResult(graph, out_guards) if extra_args or extra_kwargs: warnings.warn( @@ -2054,19 +2107,19 @@ def graph_with_interpreter(*args): FutureWarning, stacklevel=2, ) - return inner(*extra_args, **extra_kwargs) + return inner(*extra_args, **extra_kwargs) # type: ignore[return-value] else: return inner -def optimize_assert(*args, **kwargs): +def optimize_assert(*args: Any, **kwargs: Any) -> OptimizeContext: if "rebuild_ctx" in kwargs and kwargs["rebuild_ctx"] is not None: # called from optimize rebuild_ctx = kwargs["rebuild_ctx"] del kwargs["rebuild_ctx"] else: - def rebuild_ctx(): + def rebuild_ctx() -> OptimizeContext: return optimize_assert(*args, **kwargs) return _optimize_assert(rebuild_ctx, *args, **kwargs) @@ -2074,14 +2127,14 @@ def rebuild_ctx(): def _optimize_assert( rebuild_ctx: Callable[[], OptimizeContext], - backend, + backend: Union[str, Callable[..., Any], None], *, - hooks=Hooks(None, None, None), - export=False, - export_constraints=None, - dynamic=None, - package=None, -): + hooks: Hooks = Hooks(None, None, None), + export: bool = False, + export_constraints: Optional[Any] = None, + dynamic: Optional[bool] = None, + package: Optional[CompilePackage] = None, +) -> OptimizeContext: """ The same as `torch._dynamo.optimize(backend, nopython=True)`, but ignores symbolic_convert.error_on_graph_break setting. @@ -2123,7 +2176,7 @@ def _optimize_assert( class TorchPatcher: @staticmethod @functools.cache - def patch(): + def patch() -> None: # A better way to disable the following would be decorate the source # functions with @torch._disable_dynamo. However, this causes issues # with torch.deploy internally. @@ -2216,17 +2269,19 @@ def patch(): ) @staticmethod - def suppress_torch_distributed_warnings(fn): - def inner_fn(*args, **kwargs): - warnings.filterwarnings( - "ignore", category=UserWarning, module="torch.distributed" - ) - return fn(*args, **kwargs) + def suppress_torch_distributed_warnings( + fn: Callable[..., Any], + ) -> Callable[..., Any]: + def inner_fn(*args: Any, **kwargs: Any) -> Any: + with torch._logging.hide_warnings( + torch._logging._internal.user_warning_filter + ): + return fn(*args, **kwargs) return inner_fn -def skip_code(code: types.CodeType): +def skip_code(code: types.CodeType) -> None: set_code_exec_strategy( code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT) ) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 8a0c6bfc2b4b..76bf400245c5 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -532,11 +532,6 @@ def unimplemented_v2( raise Unsupported(msg) -def warning(msg: str) -> None: - counters["warnings"][msg] += 1 - assert msg != os.environ.get("BREAK", False) - - # KeyError has special handling for its args # see https://github.com/python/cpython/blob/3.11/Objects/exceptions.c#L2534 for details class KeyErrorMsg: diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index c4fbc62ea5db..f48c14862ac0 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -1,5 +1,3 @@ -# This module contains functions that *will be allowed* by dynamo - """ This module contains utility functions that are explicitly allowed to be called during TorchDynamo compilation. These functions are carefully vetted to ensure they work diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 58085696b78b..0bbdd91b6ae2 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2499,5 +2499,15 @@ "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." ] } + ], + "GB0251": [ + { + "Gb_type": "Unsupported output type for nonstrict_trace-ed function", + "Context": "Function: {fn.__name__}", + "Explanation": "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list) are allowed as output. The result of this call contains an unsupported type.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } \ No newline at end of file diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 33ddb4c303bc..c6444d3acc6c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -104,6 +104,7 @@ ChainedSource, ConstantSource, ConstDictKeySource, + DataclassFieldsSource, DefaultsSource, DictGetItemSource, DictSubclassGetItemSource, @@ -146,6 +147,7 @@ from .utils import ( builtin_dict_keys, common_constant_types, + dataclass_fields, dict_keys, get_custom_getattr, get_torch_function_mode_stack, @@ -164,7 +166,7 @@ ) -guard_manager_testing_hook_fn: Optional[Callable[[Any, Any], Any]] = None +guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None try: import numpy as np @@ -309,6 +311,7 @@ def get_manager_line(self, guard_manager, accessor_str=None): s = t + ": source=" + source if accessor_str: s += ", " + accessor_str + s += f", type={guard_manager.type_of_guarded_value()}" return s def construct_dict_manager_string(self, mgr, body): @@ -451,6 +454,7 @@ def _get_closure_vars(): "___tuple_iterator_len": tuple_iterator_len, "___normalize_range_iter": normalize_range_iter, "___tuple_iterator_getitem": tuple_iterator_getitem, + "___dataclass_fields": dataclass_fields, "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, @@ -700,6 +704,9 @@ def __init__( ] = {} self._cached_duplicate_input_guards: set[tuple[str, str]] = set() self.serialization_mode = serialization_mode + self.guard_nn_modules = config.guard_nn_modules and justknobs_check( + "pytorch/compiler:guard_nn_modules" + ) def guard_on_dict_keys_and_ignore_order(self, example_value, guard): dict_mgr = self.get_guard_manager(guard) @@ -1325,6 +1332,14 @@ def get_guard_manager_from_source(self, source): example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, DataclassFieldsSource): + assert base_guard_manager + out = base_guard_manager.lambda_manager( + python_lambda=lambda x: dataclass_fields(x), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) else: raise AssertionError( f"missing guard manager builder {source} - {source.name()}" @@ -1588,7 +1603,7 @@ def id_match_unchecked(self, guard: Guard): val = self.get(guard.name) id_val = self.id_ref(val, guard.name) code = f"___check_obj_id({ref}, {id_val})" - self._set_guard_export_info(guard, [code]) + self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH") self.get_guard_manager(guard).add_id_match_guard( id_val, get_verbose_code_parts(code, guard) @@ -1841,7 +1856,9 @@ def NN_MODULE(self, guard: Guard): val = self.get(guard.name) if hasattr(val, "training"): assert istype(val.training, bool) - self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) + if not self.guard_nn_modules: + # If guard_nn_modules is true, we will guard on the right set of guards + self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) else: exc.unimplemented_v2( gb_type="Attempted to guard on uninitialized nn.Module", @@ -2468,7 +2485,9 @@ def TENSOR_MATCH(self, guard: Guard, value=None): self._set_guard_export_info(guard, code) # A util that in the case of export, adds data onto guards - def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): + def _set_guard_export_info( + self, guard, code_list, provided_guarded_object=None, provided_func_name=None + ): # WARNING: It is important that cur_frame/caller do NOT stay in # the current frame, because they will keep things live longer # than they should. See TestMisc.test_release_module_memory @@ -2477,7 +2496,7 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None) caller = cur_frame.f_back del cur_frame assert caller is not None - func_name = caller.f_code.co_name + func_name = provided_func_name or caller.f_code.co_name del caller # We use func_name for export, so might as well get a nice defensive check out of it assert func_name in self.__class__.__dict__, ( @@ -2837,6 +2856,32 @@ def __init__( if not justknobs_check("pytorch/compiler:guard_nn_modules"): log.warning("guard_nn_modules is turned off using justknobs killswitch") + # TODO Be more explicit about the behavior for the users. + if ( + torch._dynamo.config.caching_precompile + and self.guards_serialization_mode != "load" + ): + _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs]) + + def guard_filter_fn(guards): + ret = [] + for keep, g in zip(_guard_filter_fn(guards), guards): + if not keep: + ret.append(False) + elif ( + g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE") + or "ID_MATCH" in g.derived_guard_types + ): + log.warning( + "%s guard on %s is dropped with caching_precompile=True.", + g.guard_type, + g.orig_guard.name, + ) + ret.append(False) + else: + ret.append(True) + return ret + sorted_guards = sorted(guards or (), key=Guard.sort_key) builder, guard_manager = self.build_guards( sorted_guards, @@ -2922,7 +2967,7 @@ def make_guard_filter_entry(guard): if guard_manager_testing_hook_fn is not None: guard_manager_testing_hook_fn( - self.guard_manager, output_graph.local_scope + self.guard_manager, output_graph.local_scope, builder ) # NB for developers: n_iters is chosen to be 1 to prevent excessive diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 5bf63c2544cd..be750d41a1dc 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -32,6 +32,7 @@ from torch.compiler._cache import CacheArtifactFactory from .bytecode_transformation import get_code_keys +from .utils import dynamo_timed, increment_frame logger = logging.getLogger(__name__) @@ -218,7 +219,7 @@ def initialize( assert not self._initialized self._inlined_sources = set() - self._innermost_fn = innermost_fn(fn) + self._innermost_fn = innermost_fn(fn) # type: ignore[assignment] assert self._innermost_fn is not None if dynamo is not None: assert isinstance(dynamo, _DynamoCacheEntry) @@ -379,60 +380,87 @@ def install(self, backends: dict[_BackendId, Any]) -> None: 3. Install the precompiled cache entries to ExtraStates on the code object. """ from torch._C._dynamo.eval_frame import _load_precompile_entry + from torch._dynamo.convert_frame import get_compile_id + from torch._guards import compile_context, CompileContext from .output_graph import get_builtins_dict self.uninstall() - for code, entry in self._codes.items(): - module = sys.modules[entry.python_module] - for alias, module_name in entry.import_sources.items(): - self._install_global( - module, alias, importlib.import_module(module_name) - ) - for function_name in entry.function_names: - fn = types.FunctionType(code, module.__dict__, function_name) - self._install_global(module, function_name, fn) - for backend_id in entry.backend_ids: - if backend_id not in backends: - raise RuntimeError( - f"Backend {backend_id} is not found in the given backends" + # Each code represents a new compile frame + # recompiles on the same frame are all saved + # under the same cache entry, so we don't have recompile ids + # i.e. If cold start had 0/0, 0/1, 1/0, 1/1, these would be + # collapsed into 0/0, 1/0 on warm. + increment_frame() + compile_id = get_compile_id(frame_state={}) + with ( + compile_context(CompileContext(compile_id)), + dynamo_timed( + "_compile.compile_inner", + phase_name="entire_frame_compile", + dynamo_compile_column_us="dynamo_cumulative_compile_time_us", + # TODO: save all relevant compilation metrics + metadata={ + "frame_key": str(torch._dynamo.utils.curr_frame), + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + }, + ), + ): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) ) - backend = backends[backend_id] - self._install_global( - module, - backend_id, - torch._dynamo.disable(backend), - ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + with dynamo_timed( + "after_deserialization", phase_name="backend_compile" + ): + backend = backends[backend_id].after_deserialization() + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) - for code, entry in self._codes.items(): - for guarded_code in entry.guarded_codes: - guards_state = pickle.loads(guarded_code.guards_state) - runtime_global_scope = sys.modules[entry.python_module].__dict__ - # The installed builtins dict might be absent from the runtime - # while loading guards. Populate it if it's missing. - if ( - builtin_dict_name - := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals - ): - builtins_dict = get_builtins_dict(runtime_global_scope) - if builtin_dict_name in runtime_global_scope: - assert runtime_global_scope[builtin_dict_name] is builtins_dict - else: - runtime_global_scope[builtin_dict_name] = builtins_dict - assert isinstance(guards_state, torch._dynamo.guards.GuardsState) - check_fn_manager = torch._dynamo.guards.CheckFunctionManager( - code, - guards_state.output_graph, - guards_serialization_mode="load", - shape_code_parts=guards_state.shape_code_parts, - runtime_global_scope=runtime_global_scope, - ) - _load_precompile_entry( - code, - check_fn_manager.guard_manager, - SerializedCode.to_code_object(guarded_code.dynamo_code), - ) + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + runtime_global_scope = sys.modules[entry.python_module].__dict__ + # The installed builtins dict might be absent from the runtime + # while loading guards. Populate it if it's missing. + if ( + builtin_dict_name + := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals + ): + builtins_dict = get_builtins_dict(runtime_global_scope) + if builtin_dict_name in runtime_global_scope: + assert ( + runtime_global_scope[builtin_dict_name] is builtins_dict + ) + else: + runtime_global_scope[builtin_dict_name] = builtins_dict + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + runtime_global_scope=runtime_global_scope, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) def cache_entry(self) -> _DynamoCacheEntry: self.validate() @@ -556,7 +584,7 @@ def load_cache_entry( PrecompileContext.record_artifact( backend.type(), key=backend.key, content=backend.content ) - backend_content[backend_id] = backend.after_deserialization() + backend_content[backend_id] = backend return cache_entry, backend_content @@ -683,7 +711,8 @@ def load( path = os.path.join(self.path_prefix, key) if os.path.exists(path): try: - return super().load_cache_entry(key) + result = super().load_cache_entry(key) + return result except Exception as e: logger.warning("Failed to load package from path %s: %s", path, str(e)) return None diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 936d1c62e9e6..db2493f26caf 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -30,6 +30,7 @@ operator as operator, os as os, pytree as pytree, + struct as struct, sys as sys, ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index f60aa57a5d40..f306d47ba5f8 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -19,6 +19,7 @@ "operator", "os", "pytree", + "struct", "sys", "fx", "tensor", diff --git a/torch/_dynamo/polyfills/struct.py b/torch/_dynamo/polyfills/struct.py new file mode 100644 index 000000000000..f4522a12f732 --- /dev/null +++ b/torch/_dynamo/polyfills/struct.py @@ -0,0 +1,27 @@ +""" +Python polyfills for struct +""" + +from __future__ import annotations + +import struct +from typing import Any +from typing_extensions import Buffer + +from ..decorators import substitute_in_graph + + +__all__ = [ + "pack", + "unpack", +] + + +@substitute_in_graph(struct.pack, can_constant_fold_through=True) # type: ignore[arg-type] +def pack(fmt: bytes | str, /, *v: Any) -> bytes: + return struct.pack(fmt, *v) + + +@substitute_in_graph(struct.unpack, can_constant_fold_through=True) # type: ignore[arg-type] +def unpack(format: bytes | str, buffer: Buffer, /) -> tuple[Any, ...]: + return struct.unpack(format, buffer) diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 6bb42bb34bc3..040f54ce70db 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -70,7 +70,8 @@ class PrecompileContext(CacheArtifactManager): The following artifact types are supported by PrecompileContext: - BundledAOTAutogradCacheArtifact - - CodeStateArtifact (from torch._dynamo.package once available) + - DynamoCodeStateArtifact + - AutotuneCacheArtifact (regular autotune results, same as Megacache) """ # Protected by the compile_lock @@ -149,8 +150,12 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo: artifacts_by_key = {} cache_info = CacheInfo() for artifact in chain(*artifacts.values()): + if artifact.type() == "autotune": + # Populate autotune cache artifacts + artifact.populate_cache() + else: + artifacts_by_key[artifact.key] = artifact cache_info.add(artifact) - artifacts_by_key[artifact.key] = artifact from torch._dynamo.package import _BackendId, DynamoCache diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index b1327473c909..cdbc1fcda037 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for reproducing and debugging issues in PyTorch's Dynamo AOT compilation. @@ -33,7 +31,7 @@ from collections.abc import Sequence from importlib import import_module from tempfile import TemporaryFile -from typing import Any, Callable, TYPE_CHECKING, Union +from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack import torch @@ -157,7 +155,7 @@ def deferred_for_real_inputs( with config.patch(repro_after=None): return inner_debug_fn(real_inputs) - def inner_debug_fn(real_inputs): + def inner_debug_fn(real_inputs: Sequence["InputType"]) -> Any: """ Aot Autograd fw_compiler and bw_compiler can have fake tensors. So, example_inputs can be fake tensors. We can call compiler_fn (which is @@ -186,7 +184,7 @@ def inner_debug_fn(real_inputs): ) failed = not same_two_models( gm, - inner_compiled_fn, + inner_compiled_fn, # type: ignore[arg-type] real_inputs, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -250,7 +248,7 @@ def inner_debug_fn(real_inputs): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def maybe_fbcode_instructions(): +def maybe_fbcode_instructions() -> str: if is_fbcode(): extra_deps_formatted = "\n".join([f' "{dep}",' for dep in extra_deps]) if len(extra_deps_formatted) > 0: @@ -283,14 +281,14 @@ def maybe_fbcode_instructions(): def generate_compiler_repro_string( - gm, - args, + gm: torch.fx.GraphModule, + args: Sequence[Any], *, - stable_output=False, - save_dir=None, - stable_hash=False, - has_distributed_ops=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + stable_hash: bool = False, + has_distributed_ops: bool = False, +) -> str: # Add distributed imports if needed distributed_imports = "" if has_distributed_ops: @@ -377,19 +375,19 @@ def generate_compiler_repro_string( def save_graph_repro( - fd, - gm, - args, - compiler_name, + fd: IO[Any], + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, *, - stable_output=False, - save_dir=None, - command="run", - accuracy=None, - tracing_mode=None, - check_str=None, - stable_hash=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + stable_hash: bool = False, +) -> None: if any( isinstance(arg, torch.fx.experimental._backward_state.BackwardState) for arg in args @@ -456,7 +454,13 @@ def save_graph_repro( fd.write("\n dist.destroy_process_group()\n") -def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): +def dump_compiler_graph_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, + *, + accuracy: Optional[Union[str, bool]] = None, +) -> None: subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -484,7 +488,9 @@ def dump_compiler_graph_state(gm, args, compiler_name, *, accuracy=None): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def dump_to_minify(gm, args, compiler_name: str): +def dump_to_minify( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: str +) -> None: out = io.StringIO() # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") @@ -495,15 +501,15 @@ def dump_to_minify(gm, args, compiler_name: str): def isolate_fails( - fx_g, - args, + fx_g: torch.fx.GraphModule, + args: Sequence[Any], compiler_name: str, - env=None, - save_dir=None, - accuracy=None, - tracing_mode=None, - check_str=None, -): + env: Optional[dict[str, Any]] = None, + save_dir: Optional[str] = None, + accuracy: Optional[Union[bool, str]] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, +) -> bool: if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -559,14 +565,16 @@ def isolate_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def inductor_fails(fx_g, args, check_str=None): +def inductor_fails( + fx_g: torch.fx.GraphModule, args: Sequence[Any], check_str: Optional[str] = None +) -> bool: has_cuda = False for arg in args: if isinstance(arg, torch.Tensor) and arg.is_cuda: has_cuda = True break - def sync(): + def sync() -> None: if has_cuda: # Ensures that segfaults are surfaced torch.cuda.synchronize() @@ -596,14 +604,19 @@ def sync(): def inductor_accuracy_fails( - fx_g, args, check_str=None, *, require_fp64=False, ignore_non_fp=False -): + fx_g: torch.fx.GraphModule, + args: Sequence[Any], + check_str: Optional[str] = None, + *, + require_fp64: bool = False, + ignore_non_fp: bool = False, +) -> bool: from torch._inductor.compile_fx import compile_fx_inner return backend_aot_accuracy_fails( fx_g, - args, - compile_fx_inner, + args, # type: ignore[arg-type] + compile_fx_inner, # type: ignore[arg-type] require_fp64=require_fp64, ignore_non_fp=ignore_non_fp, ) @@ -617,7 +630,9 @@ def inductor_accuracy_fails( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def repro_common(options, mod, load_args): +def repro_common( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, Sequence[Any]]: # Invariant for graphs we generate with the repro script assert not any(mod.named_parameters()) for n, b in mod.named_buffers(): @@ -660,7 +675,7 @@ def repro_common(options, mod, load_args): return mod, args -ACCURACY_FAILS: dict[str, Callable[[nn.Module, Any], bool]] = { +ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { "": inductor_fails, # This might look inverted but it's not. strict_accuracy means "we will # minify any time we see anything that diverges", whereas accuracy is more @@ -673,7 +688,7 @@ def repro_common(options, mod, load_args): } -def repro_minifier_query(options, mod, load_args): +def repro_minifier_query(options: Any, mod: nn.Module, load_args: Any) -> None: mod, args = repro_common(options, mod, load_args) fail_fn = functools.partial( ACCURACY_FAILS[options.accuracy], @@ -685,7 +700,7 @@ def repro_minifier_query(options, mod, load_args): sys.exit(0) -def repro_minify(options, mod, load_args): +def repro_minify(options: Any, mod: nn.Module, load_args: Any) -> None: from functorch.compile import minifier mod, args = repro_common(options, mod, load_args) @@ -722,7 +737,7 @@ def repro_minify(options, mod, load_args): ) -def repro_analyze(options, mod, load_args): +def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.hooks import intermediate_hook @@ -740,7 +755,7 @@ def repro_analyze(options, mod, load_args): known_names = set() - def save_hook(name, val): + def save_hook(name: str, val: Any) -> None: known_names.add(name) if not options.skip_saving_inductor_intermediates: writer.write_tensor(os.path.join("inductor", name), val) @@ -757,10 +772,10 @@ def save_hook(name, val): tqdm(desc="Saving inductor intermediates", total=total) as pbar, ): assert not isinstance(compiled, str) - compiled(new_args) + compiled(new_args) # type: ignore[arg-type] assert not new_args - def compare_tuples(tuple1, tuple2): + def compare_tuples(tuple1: tuple[Any], tuple2: tuple[Any]) -> Optional[str]: diff_indices = [i for i in range(len(tuple1)) if tuple1[i] != tuple2[i]] diff_values = [(tuple1[i], tuple2[i]) for i in diff_indices] @@ -769,7 +784,7 @@ def compare_tuples(tuple1, tuple2): else: return " and ".join(f"{a} != {b}" for a, b in diff_values) - def check_hook(name, val): + def check_hook(name: str, val: Any) -> None: meta = writer.compute_tensor_metadata(val) meta2 = reader.read_tensor_metadata(os.path.join("inductor", name)) reason = compare_tuples(meta, meta2) @@ -783,15 +798,15 @@ def check_hook(name, val): intermediate_hook(check_hook), tqdm(desc="Checking inductor determinism", total=total) as pbar, ): - compiled(new_args) + compiled(new_args) # type: ignore[arg-type] assert not new_args class WriterInterp(fx.Interpreter): - def __init__(self, mod, subdir) -> None: + def __init__(self, mod: torch.nn.Module, subdir: str) -> None: super().__init__(mod) self.subdir = subdir - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -802,13 +817,13 @@ def run_node(self, n): # NB: the module cast doesn't actually do anything, since there are no # parameters/buffers on the module if not options.skip_saving_float64_intermediates: - new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] with tqdm(desc="Saving float64 intermediates", total=total) as pbar: WriterInterp(new_mod, "float64").boxed_run(new_args) assert not new_args class ExactReaderInterp(fx.Interpreter): - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -823,7 +838,7 @@ def run_node(self, n): # TODO: check eager determinism if not options.skip_check_deterministic: - new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) + new_mod, new_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args)) # type: ignore[arg-type] with tqdm(desc="Checking float64 determinism", total=total) as pbar: ExactReaderInterp(new_mod).boxed_run(new_args) assert not new_args @@ -831,7 +846,7 @@ def run_node(self, n): # Now that we've saved everything, interp through the eager graph # and do comparisons class ReaderInterp(fx.Interpreter): - def run_node(self, n): + def run_node(self, n: torch.fx.Node) -> Any: r = super().run_node(n) name = n.name if name in known_names: @@ -839,7 +854,7 @@ def run_node(self, n): float64 = reader.read_tensor(os.path.join("float64", name)) logged = False - def log_error(msg, *args): + def log_error(msg: str, *args: Any) -> None: nonlocal logged logged = True pbar.write(f"DIVERGED at {name}: {msg % args}") @@ -861,12 +876,14 @@ def log_error(msg, *args): assert not args -def repro_get_args(options, mod, load_args): +def repro_get_args( + options: Any, mod: nn.Module, load_args: Any +) -> tuple[torch.fx.GraphModule, list[Any]]: mod, args = repro_common(options, mod, load_args) - return mod, args + return mod, args # type: ignore[return-value] -def repro_run(options, mod, load_args): +def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: from torch._inductor.compile_fx import compile_fx_inner mod, args = repro_common(options, mod, load_args) @@ -881,7 +898,7 @@ def repro_run(options, mod, load_args): # seems counterintuitive if not same_two_models( mod, - compiled, + compiled, # type: ignore[arg-type] args, only_fwd=True, ignore_non_fp=config.repro_ignore_non_fp, @@ -903,17 +920,17 @@ def repro_run(options, mod, load_args): # TODO: lazily load the inputs or something, rather than cloning them def run_repro( - mod, - load_args, + mod: nn.Module, + load_args: Any, *, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - tracing_mode=None, - patch_code=None, - check_str=None, - **kwargs, -): + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + patch_code: Optional[str] = None, + check_str: Optional[str] = None, + **kwargs: Any, +) -> Any: for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -946,7 +963,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 80191f2d6cef..898946d6f89f 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for reproducing and debugging issues in Dynamo after graph capture. @@ -26,12 +24,12 @@ import shutil import sys import textwrap +from collections.abc import Sequence from importlib import import_module -from typing import Union +from typing import Any, Callable, Optional, Union import torch import torch.fx as fx -from torch._dynamo.backends.registry import CompiledFn from torch._dynamo.debug_utils import ( AccuracyError, backend_accuracy_fails, @@ -53,7 +51,7 @@ from torch.hub import tqdm from .. import config -from ..backends.registry import lookup_backend, register_debug_backend +from ..backends.registry import CompilerFn, lookup_backend, register_debug_backend from ..debug_utils import clone_inputs_retaining_gradness @@ -68,7 +66,11 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def _accuracy_fails(gm, example_inputs, compiler_fn): +def _accuracy_fails( + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule], +) -> bool: return backend_accuracy_fails( gm, example_inputs, @@ -79,29 +81,33 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): class WrapBackendDebug: - def __init__(self, unconfigured_compiler_fn, compiler_name: str) -> None: + def __init__( + self, unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] + ) -> None: functools.wraps(unconfigured_compiler_fn)(self) - self._torchdynamo_orig_backend = unconfigured_compiler_fn # type: ignore[attr-defined] + self._torchdynamo_orig_backend = unconfigured_compiler_fn self._compiler_name = compiler_name if hasattr(unconfigured_compiler_fn, "__name__"): self.__name__ = unconfigured_compiler_fn.__name__ if hasattr(unconfigured_compiler_fn, "compiler_name"): - self.__name__ = unconfigured_compiler_fn.compiler_name + self.__name__ = unconfigured_compiler_fn.compiler_name # type: ignore[attr-defined] if hasattr(unconfigured_compiler_fn, "get_compiler_config"): self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] - def __call__(self, gm, example_inputs, **kwargs): + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: list[Any], **kwargs: Any + ) -> torch.fx.GraphModule: compiler_fn = functools.partial(self._torchdynamo_orig_backend, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": - def add_paths(exc): - exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") + def add_paths(exc: Exception) -> None: + exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py") # type: ignore[attr-defined] if use_buck: - exc.buck_command = " ".join( + exc.buck_command = " ".join( # type: ignore[attr-defined] BUCK_CMD_PREFIX - + [BuckTargetWriter(exc.minifier_path).cmd_line_path] + + [BuckTargetWriter(exc.minifier_path).cmd_line_path] # type: ignore[attr-defined] ) if config.repro_level == 3: @@ -111,7 +117,7 @@ def add_paths(exc): if config.repro_level == 4: # Check Accuracy compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) - if _accuracy_fails(gm, example_inputs, compiler_fn): + if _accuracy_fails(gm, example_inputs, compiler_fn): # type: ignore[arg-type] log.warning( "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error." ) @@ -126,7 +132,7 @@ def add_paths(exc): else: try: compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs) - run_fwd_maybe_bwd(compiled_gm, example_inputs) + run_fwd_maybe_bwd(compiled_gm, example_inputs) # type: ignore[arg-type] except Exception as exc: log.warning( "Compiled Fx GraphModule failed. Creating script to minify the error." @@ -149,10 +155,12 @@ def add_paths(exc): else: compiled_gm = compiler_fn(gm, example_inputs) - return compiled_gm + return compiled_gm # type: ignore[return-value] -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): +def wrap_backend_debug( + unconfigured_compiler_fn: CompilerFn, compiler_name: Optional[str] +) -> WrapBackendDebug: """ A minifier decorator that wraps the TorchDynamo produced Fx graph modules. As opposed to wrap_compiler_debug, this wrapper intercepts at the @@ -170,15 +178,15 @@ def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): def generate_dynamo_fx_repro_string( - gm, - args, - compiler_name, - check_accuracy=False, + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, *, - stable_output=False, - save_dir=None, - command="run", -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", +) -> str: """ Generate a repro string for backend-agnostic minified version. """ @@ -225,7 +233,12 @@ def generate_dynamo_fx_repro_string( ) -def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): +def dump_backend_repro_as_file( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: """ Saves the repro to a repro.py file """ @@ -253,7 +266,12 @@ def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False): shutil.copyfile(file_name, latest_repro) -def dump_backend_state(gm, args, compiler_name, check_accuracy=False): +def dump_backend_state( + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: Optional[str], + check_accuracy: bool = False, +) -> None: """ Dumps the dynamo graph to repro the issue. 1) It tries to convert Fx GraphModule to a string. If we can, it writes to a @@ -271,7 +289,9 @@ def dump_backend_state(gm, args, compiler_name, check_accuracy=False): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def dump_to_minify_after_dynamo(gm, args, compiler_name): +def dump_to_minify_after_dynamo( + gm: torch.fx.GraphModule, args: Sequence[Any], compiler_name: Optional[str] +) -> None: # TODO: factor this out subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): @@ -295,8 +315,8 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): @register_debug_backend # type: ignore[arg-type] def dynamo_minifier_backend( - gm: fx.GraphModule, example_inputs, compiler_name: CompiledFn -): + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -336,7 +356,9 @@ def dynamo_minifier_backend( @register_debug_backend # type: ignore[arg-type] -def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): +def dynamo_accuracy_minifier_backend( + gm: fx.GraphModule, example_inputs: Sequence[Any], compiler_name: Optional[str] +) -> fx.GraphModule: from functorch.compile import minifier compiler_fn = lookup_backend(compiler_name) @@ -366,7 +388,12 @@ def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name): return gm -def backend_fails(gm, example_inputs, compiler_fn, orig_failure): +def backend_fails( + gm: fx.GraphModule, + example_inputs: Sequence[Any], + compiler_fn: CompilerFn, + orig_failure: Sequence[Any], +) -> bool: """ Minifier uses this function to identify if the minified graph module fails with the same error. @@ -383,8 +410,8 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure): try: # Run the original gm to check eager validity run_fwd_maybe_bwd(gm, clone_inputs_retaining_gradness(example_inputs)) - compiled_gm = compiler_fn(gm, example_inputs) - run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) + compiled_gm = compiler_fn(gm, example_inputs) # type: ignore[arg-type] + run_fwd_maybe_bwd(compiled_gm, clone_inputs_retaining_gradness(example_inputs)) # type: ignore[arg-type] except Exception as e: new_failure = str(e) if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5: @@ -397,7 +424,7 @@ def backend_fails(gm, example_inputs, compiler_fn, orig_failure): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def run_load_args(options, mod, load_args): +def run_load_args(options: Any, mod: torch.nn.Module, load_args: Any) -> list[Any]: if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -423,7 +450,7 @@ def run_load_args(options, mod, load_args): return args -def repro_minify(options, mod, load_args): +def repro_minify(options: Any, mod: torch.nn.Module, load_args: Any) -> None: args = run_load_args(options, mod, load_args) # Setup debug minifier compiler @@ -450,20 +477,20 @@ def repro_minify(options, mod, load_args): opt_mod(*args) -def repro_run(options, mod, load_args): +def repro_run(options: Any, mod: torch.nn.Module, load_args: Any) -> None: opt_mod = torch._dynamo.optimize(options.backend)(mod) if options.accuracy != "": mod.eval() - opt_mod.eval() + opt_mod.eval() # type: ignore[union-attr] with torch.amp.autocast("cuda", enabled=options.autocast): # TODO: disable clone args = run_load_args(options, mod, load_args) - assert same_two_models(mod, mod, args), "Eager itself failed" + assert same_two_models(mod, mod, args), "Eager itself failed" # type: ignore[arg-type] if not same_two_models( - mod, - opt_mod, + mod, # type: ignore[arg-type] + opt_mod, # type: ignore[arg-type] args, only_fwd=config.repro_forward_only, ignore_non_fp=config.repro_ignore_non_fp, @@ -472,26 +499,29 @@ def repro_run(options, mod, load_args): else: with torch.amp.autocast("cuda", enabled=options.autocast): args = run_load_args(options, mod, load_args) - run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) + run_fwd_maybe_bwd(mod, args, only_fwd=options.only_fwd, disable_clone=True) # type: ignore[arg-type] del args args = run_load_args(options, mod, load_args) run_fwd_maybe_bwd( - opt_mod, args, only_fwd=options.only_fwd, disable_clone=True + opt_mod, # type: ignore[arg-type] + args, + only_fwd=options.only_fwd, + disable_clone=True, # type: ignore[arg-type] ) def run_repro( - mod, - load_args, + mod: torch.nn.Module, + load_args: Any, *, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - autocast=False, - backend="inductor", - **kwargs, -): + save_dir: Optional[str] = None, + autocast: bool = False, + backend: str = "inductor", + **kwargs: Any, +) -> None: for k in kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -517,7 +547,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/repro/aoti.py b/torch/_dynamo/repro/aoti.py index c3fab6bd086a..808383e68e51 100644 --- a/torch/_dynamo/repro/aoti.py +++ b/torch/_dynamo/repro/aoti.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Utilities for debugging and reproducing issues in Ahead of Time with Inductor (AOTI) compilation. @@ -26,8 +24,9 @@ import shutil import sys import textwrap +from collections.abc import Sequence from importlib import import_module -from typing import Any, Optional, Union +from typing import Any, IO, Optional, Union import torch from torch._dynamo.debug_utils import ( @@ -54,7 +53,7 @@ class AOTIMinifierError(Exception): - def __init__(self, original_exception): + def __init__(self, original_exception: Union[str, Exception]) -> None: additional_message = "This error is caused by a bug in the AOTI minifier, please report a bug to PyTorch" full_message = f"{additional_message}: {str(original_exception)}" super().__init__(full_message) @@ -66,7 +65,7 @@ def dump_to_minify( compiler_name: str, command: str = "minify", options: Optional[dict[str, Any]] = None, -): +) -> None: """ If command is "minify": Dump exported_program to `debug_dir/minifier/minifier_launcher.py`, with minify command. @@ -111,8 +110,8 @@ def dump_to_minify( log.warning("No write permissions for %s", file_name) -def get_module_string(gm): - def _convert_to_comment(s_): +def get_module_string(gm: torch.fx.GraphModule) -> str: + def _convert_to_comment(s_: str) -> str: s = s_.split("\n") if len(s) == 1: return "# " + s_ @@ -132,21 +131,21 @@ def _convert_to_comment(s_): def save_graph_repro_ep( - fd, - compiler_name, + fd: IO[Any], + compiler_name: str, *, exported_program: Optional[ExportedProgram] = None, gm: Optional[torch.nn.Module] = None, args: Optional[tuple[Any]] = None, config_patches: Optional[dict[str, str]] = None, - stable_output=False, - save_dir=None, - command="run", - accuracy=None, - check_str=None, - module_in_comment=False, - strict=False, -): + stable_output: bool = False, + save_dir: Optional[str] = None, + command: str = "run", + accuracy: Optional[Union[str, bool]] = None, + check_str: Optional[str] = None, + module_in_comment: bool = False, + strict: bool = False, +) -> None: # Save graph for reproducing the error. # Either exported_program or gm will be saved, depending on which one is defined. # Only one of exported_program and gm should be defined. @@ -166,7 +165,7 @@ def save_graph_repro_ep( gm = exported_program.module() # save a graph preview using gm - module_string = get_module_string(gm) + module_string = get_module_string(gm) # type: ignore[arg-type] fd.write(module_string) # save a graph repro using exported_program @@ -190,14 +189,14 @@ def save_graph_repro_ep( def dump_compiler_graph_state( - gm, - args, - compiler_name, + gm: torch.fx.GraphModule, + args: Sequence[Any], + compiler_name: str, *, - config_patches=None, - accuracy=None, - strict=False, -): + config_patches: Optional[dict[str, str]] = None, + accuracy: Optional[Union[str, bool]] = None, + strict: bool = False, +) -> None: subdir = os.path.join(minifier_dir(), "checkpoints") if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) @@ -234,12 +233,12 @@ def dump_compiler_graph_state( def generate_compiler_repro_exported_program( - exported_program, + exported_program: ExportedProgram, *, options: Optional[dict[str, str]] = None, - stable_output=False, - save_dir=None, -): + stable_output: bool = False, + save_dir: Optional[str] = None, +) -> str: model_str = textwrap.dedent( f""" {generate_env_vars_string(stable_output=stable_output)} @@ -261,8 +260,10 @@ def generate_compiler_repro_exported_program( if hasattr(torch.version, "git_version"): model_str += f"# torch git version: {torch.version.git_version}\n\n\n" model_str += _cuda_system_info_comment() - - ep_path = os.path.join(save_dir, "exported_program.pt2") + if save_dir: + ep_path = os.path.join(save_dir, "exported_program.pt2") + else: + ep_path = "exported_program.pt2" torch.export.save(exported_program, ep_path) model_str += f"exported_program = torch.export.load('{ep_path}')\n" @@ -271,7 +272,7 @@ def generate_compiler_repro_exported_program( return model_str -def repro_load_args(load_args, save_dir): +def repro_load_args(load_args: Any, save_dir: Optional[str]) -> tuple[Any]: if not hasattr(load_args, "_version"): log.warning( "load_args does not have a _version attribute, please file a bug to PyTorch " @@ -297,19 +298,29 @@ def repro_load_args(load_args, save_dir): return tuple(args) -def repro_common(options, exported_program): +def repro_common( + options: Any, exported_program: ExportedProgram +) -> tuple[torch.fx.GraphModule, Any, Any]: torch._inductor.config.generate_intermediate_hooks = True mod = exported_program.module() args, kwargs = exported_program.example_inputs - return mod, args, kwargs + return mod, args, kwargs # type: ignore[return-value] -def repro_get_args(options, exported_program, config_patches): +def repro_get_args( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> tuple[torch.fx.GraphModule, Any, Any]: mod, args, kwargs = repro_common(options, exported_program) return mod, args, kwargs -def repro_run(options, exported_program, config_patches): +def repro_run( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: from torch._inductor import _aoti_compile_and_package_inner gm, args, kwargs = repro_common(options, exported_program) @@ -337,7 +348,10 @@ def repro_run(options, exported_program, config_patches): def export_for_aoti_minifier( - gm, tuple_inputs, strict=False, skip_export_error=True + gm: torch.nn.Module, + tuple_inputs: tuple[Any], + strict: bool = False, + skip_export_error: bool = True, ) -> Optional[torch.nn.Module]: # Some graphs cannot be used for AOTI/export (illegal graphs), these should be # considered as graphs that don't fail in the minifier, so the minifier keeps searching. @@ -372,7 +386,11 @@ def export_for_aoti_minifier( return None -def repro_minify(options, exported_program, config_patches): +def repro_minify( + options: Any, + exported_program: ExportedProgram, + config_patches: Optional[dict[str, Any]], +) -> None: from functorch.compile import minifier from torch._inductor import _aoti_compile_and_package_inner from torch._inductor.compile_fx import _aoti_flatten_inputs @@ -397,7 +415,11 @@ def repro_minify(options, exported_program, config_patches): need_sync = True break - def module_fails(gm, flat_example_inputs, check_str=None): + def module_fails( + gm: torch.fx.GraphModule, + flat_example_inputs: list[Any], + check_str: Optional[str] = None, + ) -> bool: # Need to export first so the in_spec and out_spec are populated tuple_inputs = tuple(flat_example_inputs) gm = export_for_aoti_minifier( @@ -447,18 +469,18 @@ def module_fails(gm, flat_example_inputs, check_str=None): def run_repro( - exported_program, + exported_program: ExportedProgram, *, config_patches: Optional[dict[str, str]] = None, - command="run", + command: str = "run", accuracy: Union[bool, str] = "", - save_dir=None, - tracing_mode=None, - check_str=None, - minifier_export_mode="python", - skip_export_error=True, - **more_kwargs, -): + save_dir: Optional[str] = None, + tracing_mode: Optional[str] = None, + check_str: Optional[str] = None, + minifier_export_mode: str = "python", + skip_export_error: bool = True, + **more_kwargs: Any, +) -> Any: for k in more_kwargs: log.warning( "Unrecognized kwarg %s; perhaps this repro was made on a newer version of PyTorch", @@ -486,7 +508,7 @@ def run_repro( formatter_class=argparse.RawTextHelpFormatter, ) - def common_flags(parser): + def common_flags(parser: argparse.ArgumentParser) -> None: accuracy_group = parser.add_mutually_exclusive_group() accuracy_group.add_argument( "--no-accuracy", diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index beaaa77671e1..0bd0a1b0ab2a 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ This module provides functionality for resuming Python execution at specific points in code, primarily used by PyTorch Dynamo for control flow handling and optimization. It implements @@ -19,7 +17,9 @@ import dataclasses import sys import types -from typing import Any, cast, Optional +from collections.abc import Iterable +from contextlib import AbstractContextManager +from typing import Any, Callable, cast, Optional from .bytecode_transformation import ( bytecode_from_template, @@ -52,7 +52,7 @@ IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue" -def _initial_push_null(insts): +def _initial_push_null(insts: list[Instruction]) -> None: if sys.version_info >= (3, 11): insts.append(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): @@ -60,7 +60,11 @@ def _initial_push_null(insts): # Generates bytecode from template and splits the code where LOAD_FAST dummy is present. -def _bytecode_from_template_with_split(template, stack_index, varname_map=None): +def _bytecode_from_template_with_split( + template: Callable[..., Any], + stack_index: int, + varname_map: Optional[dict[str, Any]] = None, +) -> tuple[list[Instruction], list[Instruction]]: template_code = bytecode_from_template(template, varname_map=varname_map) template_code.append(create_instruction("POP_TOP")) @@ -78,7 +82,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None): ), (None, None), ) - assert dummy_idx is not None + assert dummy_idx is not None and dummy_inst is not None # replace LOAD_FAST dummy with first NOP marking exception area overwrite_instruction(dummy_inst, [create_instruction("NOP")]) @@ -90,7 +94,7 @@ def _bytecode_from_template_with_split(template, stack_index, varname_map=None): return template_code[: dummy_idx + 1], template_code[dummy_idx + 1 :] -def _try_except_tf_mode_template(dummy, stack_var_name): +def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None: # NOTE: Make sure this name matches what is generated by symbolic_convert:import_source # on torch._dynamo.utils. global __import_torch_dot__dynamo_dot_utils @@ -108,7 +112,9 @@ class ReenterWith: stack_index: int target_values: Optional[tuple[Any, ...]] = None - def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction]): + def try_except_torch_function_mode( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> list[Instruction]: """ Codegen based off of: try: @@ -130,7 +136,9 @@ def try_except_torch_function_mode(self, code_options, cleanup: list[Instruction # If we do not want to destroy the stack, we can do the same thing as a # `SETUP_WITH` block, only that we store the context manager in a local_symbol - def try_finally(self, code_options, cleanup: list[Instruction]): + def try_finally( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> list[Instruction]: """ Codegen based off of: load args @@ -161,7 +169,7 @@ def try_finally(self, code_options, cleanup: list[Instruction]): ] ) - def _template(ctx, dummy): + def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: ctx.__enter__() try: dummy @@ -174,7 +182,9 @@ def _template(ctx, dummy): cleanup[:] = epilogue + cleanup return create_ctx + setup_try_finally - def __call__(self, code_options, cleanup): + def __call__( + self, code_options: dict[str, Any], cleanup: list[Instruction] + ) -> tuple[list[Instruction], Optional[Instruction]]: """ Codegen based off of: with ctx(args): @@ -194,7 +204,7 @@ def __call__(self, code_options, cleanup): ] ) - def _template(ctx, dummy): + def _template(ctx: AbstractContextManager[Any], dummy: Any) -> None: with ctx: dummy @@ -242,7 +252,11 @@ class ResumeFunctionMetadata: block_target_offset_remap: Optional[dict[int, int]] = None -def _filter_iter(l1, l2, cond): +def _filter_iter( + l1: Iterable[Any], + l2: Iterable[Any], + cond: Callable[[Any, Any], bool], +) -> list[Any]: """ Two-pointer conditional filter. e.g. _filter_iter(insts, sorted_offsets, lambda i, o: i.offset == o) @@ -261,7 +275,7 @@ def _filter_iter(l1, l2, cond): return res -def _load_tuple_and_call(tup): +def _load_tuple_and_call(tup: tuple[Any, ...]) -> list[Instruction]: insts: list[Instruction] = [] _initial_push_null(insts) insts.extend(create_load_const(val) for val in tup) @@ -274,7 +288,7 @@ class ContinueExecutionCache: generated_code_metadata = ExactWeakKeyDictionary() @classmethod - def lookup(cls, code, lineno, *key): + def lookup(cls, code: types.CodeType, lineno: int, *key: Any) -> types.CodeType: if code not in cls.cache: cls.cache[code] = {} key = tuple(key) @@ -285,8 +299,8 @@ def lookup(cls, code, lineno, *key): @classmethod def generate( cls, - code, - lineno, + code: types.CodeType, + lineno: int, offset: int, setup_fn_target_offsets: tuple[int, ...], # only used in Python 3.11+ nstack: int, @@ -321,7 +335,9 @@ def generate( is_py311_plus = sys.version_info >= (3, 11) meta = ResumeFunctionMetadata(code) - def update(instructions: list[Instruction], code_options: dict[str, Any]): + def update( + instructions: list[Instruction], code_options: dict[str, Any] + ) -> None: meta.instructions = copy.deepcopy(instructions) args = [f"___stack{i}" for i in range(nstack)] @@ -479,7 +495,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): inst.exn_tab_entry and inst.exn_tab_entry.target in old_hook_target_remap ): - inst.exn_tab_entry.target = old_hook_target_remap[ + inst.exn_tab_entry.target = old_hook_target_remap[ # type: ignore[assignment] inst.exn_tab_entry.target ] @@ -491,7 +507,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]): return new_code @staticmethod - def unreachable_codes(code_options) -> list[Instruction]: + def unreachable_codes(code_options: dict[str, Any]) -> list[Instruction]: """Codegen a `raise None` to make analysis work for unreachable code""" return [ create_load_const(None), @@ -500,8 +516,13 @@ def unreachable_codes(code_options) -> list[Instruction]: @classmethod def generate_based_on_original_code_object( - cls, code, lineno, offset: int, setup_fn_target_offsets: tuple[int, ...], *args - ): + cls, + code: types.CodeType, + lineno: int, + offset: int, + setup_fn_target_offsets: tuple[int, ...], + *args: Any, + ) -> types.CodeType: """ This handles the case of generating a resume into code generated to resume something else. We want to always generate starting @@ -517,7 +538,7 @@ def generate_based_on_original_code_object( def find_new_offset( instructions: list[Instruction], code_options: dict[str, Any] - ): + ) -> None: nonlocal new_offset (target,) = (i for i in instructions if i.offset == offset) # match the functions starting at the last instruction as we have added a prefix @@ -541,7 +562,7 @@ def find_new_offset( def remap_block_offsets( instructions: list[Instruction], code_options: dict[str, Any] - ): + ) -> None: # NOTE: each prefix block generates exactly one PUSH_EXC_INFO, # so we can tell which block a prefix PUSH_EXC_INFO belongs to, # by counting. Then we can use meta.prefix_block-target_offset_remap diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index a109d11e473d..58ed0da5fb2d 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Side effect tracking and management for TorchDynamo's compilation system. @@ -28,11 +26,12 @@ import inspect import warnings import weakref -from collections.abc import MutableMapping +from collections.abc import Generator, MutableMapping from types import CellType from typing import Any, Optional, TYPE_CHECKING import torch.nn +from torch._dynamo.variables.misc import AutogradFunctionContextVariable from . import graph_break_hints, utils, variables from .bytecode_transformation import ( @@ -58,21 +57,25 @@ if TYPE_CHECKING: - from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.output_graph import OutputGraph + from torch._dynamo.symbolic_convert import InstructionTranslatorBase + from torch._dynamo.variables.lists import ListVariable -def _manual_dict_setitem(dict_from, dict_to, mro_index): +def _manual_dict_setitem( + dict_from: dict[Any, Any], dict_to: dict[Any, Any], mro_index: int +) -> None: # Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have # to be careful because we don't want to trigger the user defined object # setitem or clear. The mro_index is used to find the dict/OrderedDict from # the class mro. dict_class = type(dict_to).__mro__[mro_index] - dict_class.clear(dict_to) + dict_class.clear(dict_to) # type: ignore[attr-defined] for k, v in dict_from.items(): - dict_class.__setitem__(dict_to, k, v) + dict_class.__setitem__(dict_to, k, v) # type: ignore[index] -def _manual_list_update(list_from, list_to): +def _manual_list_update(list_from: list[Any], list_to: list[Any]) -> None: list.clear(list_to) list.extend(list_to, list_from) @@ -103,13 +106,27 @@ class SideEffects: def __init__( self, - output_graph, - id_to_variable=None, - store_attr_mutations=None, - keepalive=None, - save_for_backward=None, - tensor_hooks=None, - ): + output_graph: "OutputGraph", + id_to_variable: Optional[dict[int, VariableTracker]] = None, + store_attr_mutations: Optional[ + dict[VariableTracker, dict[str, VariableTracker]] + ] = None, + keepalive: Optional[list[Any]] = None, + save_for_backward: Optional[ + list[tuple[AutogradFunctionContextVariable, list[VariableTracker]]] + ] = None, + tensor_hooks: Optional[ + dict[ + int, + tuple[ + "variables.TensorVariable", + VariableTracker, + "variables.RemovableHandleVariable", + str, + ], + ] + ] = None, + ) -> None: super().__init__() self.output_graph_weakref = weakref.ref(output_graph) self.id_to_variable = id_to_variable or {} @@ -122,7 +139,7 @@ def __init__( self._has_existing_dict_mutation = False # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph. # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd. - self.ca_final_callbacks_var = None + self.ca_final_callbacks_var: Optional[ListVariable] = None # Tracks VariableTracker objects whose mutations can be skipped. # For normal mutated variables, Dynamo generates code to replay/reconstruct @@ -131,14 +148,14 @@ def __init__( # execution but don't need to be replayed in the generated code. # Used for temporary mutations in contexts like torch.func.functional_call, # where module parameters/buffers are modified but later restored. - self.ignore_mutation_on_these_variables = set() + self.ignore_mutation_on_these_variables: set[VariableTracker] = set() - def ignore_mutations_on(self, var): + def ignore_mutations_on(self, var: VariableTracker) -> None: """Mutations to this variable will be executed but not not tracked, typically used for temporary mutations that are later restored.""" self.ignore_mutation_on_these_variables.add(var) - def stop_ignoring_mutations_on(self, var): + def stop_ignoring_mutations_on(self, var: VariableTracker) -> None: """Remove a variable from the skip mutation set, restoring normal mutation tracking.""" if var in self.ignore_mutation_on_these_variables: self.ignore_mutation_on_these_variables.remove(var) @@ -175,10 +192,12 @@ def diff(self, other: "SideEffects") -> Optional[str]: else: return None - def clone(self): + def clone(self) -> "SideEffects": """Create a shallow copy""" + ref = self.output_graph_weakref() + assert ref is not None return self.__class__( - output_graph=self.output_graph_weakref(), + output_graph=ref, id_to_variable=dict(self.id_to_variable), store_attr_mutations={ k: dict(v) for k, v in self.store_attr_mutations.items() @@ -188,36 +207,36 @@ def clone(self): tensor_hooks=self.tensor_hooks, ) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: return id(item) in self.id_to_variable - def __getitem__(self, item): + def __getitem__(self, item: Any) -> VariableTracker: return self.id_to_variable[id(item)] - def should_allow_side_effects_under_checkpoint(self): + def should_allow_side_effects_under_checkpoint(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.under_activation_checkpoint and output_graph.current_tx.output.current_tracer.allow_side_effects_under_checkpoint ) - def should_allow_externally_visible_side_effects_in_subtracer(self): + def should_allow_externally_visible_side_effects_in_subtracer(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.unsafe_allow_externally_visible_side_effects ) - def is_reconstructing_generator(self): + def is_reconstructing_generator(self) -> bool: output_graph = self.output_graph_weakref() - return ( + return bool( output_graph and output_graph.current_tx.output.current_tracer.is_reconstructing_generator ) - def check_allowed_side_effect(self, item: VariableTracker): + def check_allowed_side_effect(self, item: VariableTracker) -> bool: from torch._dynamo.variables.misc import AutogradFunctionContextVariable # People do things like self.dim = dim inside autograd.Function. @@ -244,15 +263,24 @@ def check_allowed_side_effect(self, item: VariableTracker): explanation="This is not supported.", hints=[], ) + return False - def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): + def store_attr( + self, item: VariableTracker, name: str, value: VariableTracker + ) -> None: assert self.is_attribute_mutation(item) self.check_allowed_side_effect(item) if item not in self.store_attr_mutations: self.store_attr_mutations[item] = {} self.store_attr_mutations[item][name] = value - def load_attr(self, item, name, deleted_ok=False, check=False): + def load_attr( + self, + item: VariableTracker, + name: str, + deleted_ok: bool = False, + check: bool = False, + ) -> VariableTracker: if check: assert self.is_attribute_mutation(item) result = self.store_attr_mutations[item][name] @@ -265,7 +293,7 @@ def load_attr(self, item, name, deleted_ok=False, check=False): ) return result - def store_cell(self, cellvar, value): + def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None: if cellvar.is_immutable(): unimplemented_v2( gb_type="Write to immutable cell", @@ -277,7 +305,7 @@ def store_cell(self, cellvar, value): assert isinstance(value, variables.VariableTracker) self.store_attr(cellvar, "cell_contents", value) - def load_cell(self, cellvar): + def load_cell(self, cellvar: VariableTracker) -> VariableTracker: assert isinstance(cellvar, variables.CellVariable) if self.has_pending_mutation_of_attr(cellvar, "cell_contents"): return self.load_attr(cellvar, "cell_contents", check=False) @@ -290,17 +318,19 @@ def load_cell(self, cellvar): hints=[*graph_break_hints.USER_ERROR], ) - def load_global(self, gvar: VariableTracker, name: str): + def load_global(self, gvar: VariableTracker, name: str) -> VariableTracker: assert isinstance(gvar, variables.VariableTracker) return self.load_attr(gvar, name) - def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker): + def store_global( + self, gvar: VariableTracker, name: str, value: VariableTracker + ) -> None: assert isinstance(gvar, variables.VariableTracker) assert isinstance(value, variables.VariableTracker) self.store_attr(gvar, name, value) @staticmethod - def cls_supports_mutation_side_effects(cls): + def cls_supports_mutation_side_effects(cls: type) -> bool: return inspect.getattr_static(cls, "__getattribute__", None) in ( object.__getattribute__, dict.__getattribute__, @@ -313,20 +343,20 @@ def cls_supports_mutation_side_effects(cls): BaseException.__getattribute__, ) - def is_attribute_mutation(self, item): + def is_attribute_mutation(self, item: VariableTracker) -> bool: return isinstance(item.mutation_type, AttributeMutation) - def has_pending_mutation(self, item): + def has_pending_mutation(self, item: VariableTracker) -> bool: return self.is_attribute_mutation(item) and bool( self.store_attr_mutations.get(item) ) - def has_pending_mutation_of_attr(self, item, name): + def has_pending_mutation_of_attr(self, item: VariableTracker, name: str) -> bool: return self.is_attribute_mutation( item ) and name in self.store_attr_mutations.get(item, ()) - def is_modified(self, item): + def is_modified(self, item: VariableTracker) -> bool: if item.is_immutable(): return False if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)): @@ -341,14 +371,14 @@ def is_modified(self, item): if self.is_attribute_mutation(item): return item in self.store_attr_mutations - return item.mutation_type.is_modified + return item.mutation_type.is_modified # type: ignore[attr-defined] def _track_obj( self, item: Any, variable: VariableTracker, - mutation_type_cls=ValueMutationExisting, - ): + mutation_type_cls: type = ValueMutationExisting, + ) -> VariableTracker: """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( @@ -370,7 +400,7 @@ def track_object_existing( self, item: Any, variable: VariableTracker, - ): + ) -> VariableTracker: return self._track_obj( item, variable, @@ -382,8 +412,8 @@ def track_object_new( cls_source: Source, user_cls: Any, variable_cls: Any, - options, - ): + options: dict[str, Any], + ) -> VariableTracker: if user_cls is torch.autograd.function.FunctionCtx: with warnings.catch_warnings(record=True): obj = torch.autograd.Function() @@ -398,7 +428,7 @@ def track_object_new( self.keepalive.append(obj) return variable - def get_variable_cls(self, user_cls): + def get_variable_cls(self, user_cls: type) -> type: from torch.overrides import TorchFunctionMode from .variables.ctx_manager import GenericContextWrappingVariable @@ -439,11 +469,11 @@ def get_variable_cls(self, user_cls): def get_example_value( self, - base_cls_vt, - cls_vt, - init_args, - ): - user_cls = cls_vt.value + base_cls_vt: VariableTracker, + cls_vt: VariableTracker, + init_args: list[VariableTracker], + ) -> Any: + user_cls = cls_vt.value # type: ignore[attr-defined] if issubclass(user_cls, torch.nn.Module): # TODO(anijain2305) - Is it possible to remove this specialization? obj = nn_module_new(user_cls) @@ -470,10 +500,10 @@ def get_example_value( def track_new_user_defined_object( self, - base_cls_vt, - cls_vt, - init_args, - ): + base_cls_vt: VariableTracker, + cls_vt: VariableTracker, + init_args: list[VariableTracker], + ) -> VariableTracker: """ Creates a UserDefinedObjectVariable (or its subclass) variable tracker and mark it for attribute mutation tracking. @@ -483,7 +513,7 @@ def track_new_user_defined_object( base_cls_vt.__new__(user_cls, *init_args) """ cls_source = cls_vt.source - user_cls = cls_vt.value + user_cls = cls_vt.value # type: ignore[attr-defined] variable_cls = self.get_variable_cls(user_cls) obj = self.get_example_value(base_cls_vt, cls_vt, init_args) @@ -500,7 +530,7 @@ def track_new_user_defined_object( def track_cell_new( self, - ): + ) -> VariableTracker: obj = object() variable = variables.CellVariable( mutation_type=AttributeMutationNew(), @@ -511,7 +541,7 @@ def track_cell_new( def track_cell_existing( self, source: Optional[Source], cell: CellType, contents: VariableTracker - ): + ) -> VariableTracker: variable = variables.CellVariable( # We don't support mutation to cell without source because we need # source to properly codegen the mutations. @@ -523,7 +553,7 @@ def track_cell_existing( self.keepalive.append(cell) return variable - def track_global_existing(self, source: Source, item: Any): + def track_global_existing(self, source: Source, item: Any) -> VariableTracker: variable = variables.NewGlobalVariable( mutation_type=AttributeMutationExisting(), source=source, @@ -532,11 +562,15 @@ def track_global_existing(self, source: Source, item: Any): self.keepalive.append(item) return variable - def track_save_for_backward(self, ctx, args): + def track_save_for_backward( + self, ctx: VariableTracker, args: list[VariableTracker] + ) -> None: assert isinstance(ctx, variables.AutogradFunctionContextVariable) self.save_for_backward.append((ctx, args)) - def track_tensor_variables_from_runahead_side_effects(self, other): + def track_runahead_tensor_and_symvar_side_effects( + self, other: "SideEffects" + ) -> None: # In higher order ops we want to keep track of tensors seen in the # speculate_subgraph so that we don't lift them again as a new input in # other speculate_subgraph or in the root tracer. @@ -544,16 +578,16 @@ def track_tensor_variables_from_runahead_side_effects(self, other): other_id = id(other_item) other_variable = other.id_to_variable[other_id] if other_id not in self.id_to_variable and isinstance( - other_variable, variables.TensorVariable + other_variable, (variables.TensorVariable, variables.SymNodeVariable) ): self.track_object_existing(other_item, other_variable) - def prune_dead_object_new(self, tx): + def prune_dead_object_new(self, tx: "InstructionTranslatorBase") -> None: # Avoid VT cycles from e.g., recursive function. visited: set[VariableTracker] = set() live_new_objects: set[VariableTracker] = set() - def visit(var: VariableTracker): + def visit(var: VariableTracker) -> None: if var in visited: return visited.add(var) @@ -569,7 +603,7 @@ def visit(var: VariableTracker): self.store_attr_mutations[var], ) - def is_live(var: VariableTracker): + def is_live(var: VariableTracker) -> bool: if isinstance(var.mutation_type, AttributeMutationNew): return var in live_new_objects return True @@ -612,7 +646,7 @@ def is_live(var: VariableTracker): k: v for k, v in self.store_attr_mutations.items() if is_live(k) } - def mutation(self, var): + def mutation(self, var: VariableTracker) -> None: if var in self.ignore_mutation_on_these_variables: return @@ -626,13 +660,13 @@ def mutation(self, var): ): self._has_existing_dict_mutation = True - def has_existing_dict_mutation(self): + def has_existing_dict_mutation(self) -> bool: return self._has_existing_dict_mutation - def _get_modified_vars(self): + def _get_modified_vars(self) -> list[VariableTracker]: return [var for var in self.id_to_variable.values() if self.is_modified(var)] - def codegen_save_tempvars(self, cg: PyCodegen): + def codegen_save_tempvars(self, cg: PyCodegen) -> None: # We must codegen modified VT to their source by default, so that # mutation and aliasing are properly accounted for. # @@ -692,7 +726,7 @@ def codegen_save_tempvars(self, cg: PyCodegen): # base_cls.__new__(user_cls, *args) if isinstance(var, variables.UserDefinedObjectVariable): - def load_new_method(): + def load_new_method() -> None: assert var.base_cls_vt is not None cg(var.base_cls_vt) # type: ignore[attr-defined] cg.extend_output([cg.create_load_attr("__new__")]) @@ -702,14 +736,15 @@ def load_new_method(): cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "object_new") ) + assert var.mutation_type.cls_source is not None cg(var.mutation_type.cls_source) # Generate the args to the __new__ method - for arg in var.init_args: + for arg in var.init_args: # type: ignore[attr-defined] cg(arg) # Call the __new__ method - cg.extend_output(create_call_function(1 + len(var.init_args), False)) + cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) @@ -726,7 +761,13 @@ def load_new_method(): ] ) - def register_hook(self, tensor, hook, handle, name): + def register_hook( + self, + tensor: "variables.TensorVariable", + hook: VariableTracker, + handle: "variables.RemovableHandleVariable", + name: str, + ) -> None: assert isinstance(tensor, variables.TensorVariable) assert isinstance(hook, variables.VariableTracker) assert ( @@ -742,10 +783,10 @@ def register_hook(self, tensor, hook, handle, name): assert not handle.idx handle.idx = idx - def remove_hook(self, idx): + def remove_hook(self, idx: int) -> None: del self.tensor_hooks[idx] - def codegen_hooks(self, cg): + def codegen_hooks(self, cg: PyCodegen) -> None: for ( tensor, hook, @@ -787,7 +828,7 @@ def codegen_hooks(self, cg): # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST. assert tensor.source, "Hooks on non input tensors NYI - should not get here" - def gen_fn(): + def gen_fn() -> None: cg(tensor) cg.extend_output([cg.create_load_attr(name)]) @@ -799,16 +840,17 @@ def gen_fn(): # be associated with the return value of register_hook(). This consumes the top of stack. cg.add_cache(handle) - def get_ca_final_callbacks_var(self): + def get_ca_final_callbacks_var(self) -> "variables.ListVariable": from .variables.base import ValueMutationNew if self.ca_final_callbacks_var is None: self.ca_final_callbacks_var = variables.ListVariable( [], mutation_type=ValueMutationNew() ) + return self.ca_final_callbacks_var - def codegen_update_mutated(self, cg: PyCodegen): + def codegen_update_mutated(self, cg: PyCodegen) -> None: suffixes = [] for var in self._get_modified_vars(): if isinstance(var, variables.ListVariable): @@ -1101,7 +1143,7 @@ def codegen_update_mutated(self, cg: PyCodegen): cg.pop_top() elif isinstance(var, variables.RandomVariable): # set correct random seed state - def gen_fn(): + def gen_fn() -> None: cg(var.source) # type: ignore[attr-defined] cg.load_attr("setstate") @@ -1121,7 +1163,7 @@ def gen_fn(): for suffix in reversed(suffixes): cg.extend_output(suffix) - def is_empty(self): + def is_empty(self) -> bool: return not ( any(map(self.is_modified, self.id_to_variable.values())) or self.tensor_hooks @@ -1129,13 +1171,15 @@ def is_empty(self): or self.tensor_hooks ) - def clear(self): + def clear(self) -> None: self.keepalive.clear() self.id_to_variable.clear() @contextlib.contextmanager -def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): +def allow_side_effects_under_checkpoint( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: assert tx.output.current_tracer.under_activation_checkpoint orig_val = tx.output.current_tracer.allow_side_effects_under_checkpoint try: @@ -1146,7 +1190,9 @@ def allow_side_effects_under_checkpoint(tx: "InstructionTranslator"): @contextlib.contextmanager -def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslator"): +def allow_externally_visible_side_effects_in_subtracer( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: orig_val = tx.output.current_tracer.unsafe_allow_externally_visible_side_effects try: tx.output.current_tracer.unsafe_allow_externally_visible_side_effects = True @@ -1156,7 +1202,9 @@ def allow_externally_visible_side_effects_in_subtracer(tx: "InstructionTranslato @contextlib.contextmanager -def disallow_side_effects_in_generator(tx: "InstructionTranslator"): +def disallow_side_effects_in_generator( + tx: "InstructionTranslatorBase", +) -> Generator[None, None, None]: orig_val = tx.output.current_tracer.is_reconstructing_generator try: tx.output.current_tracer.is_reconstructing_generator = True diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index e2ee525ed644..7d700b2539c9 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -751,6 +751,22 @@ def name(self): return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" +@dataclasses.dataclass(frozen=True) +class DataclassFieldsSource(ChainedSource): + def reconstruct(self, codegen: "PyCodegen"): + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "dataclass_fields") + ) + codegen(self.base) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self): + return self.base.guard_source() + + def name(self): + return f"___dataclass_fields({self.base.name()})" + + @dataclasses.dataclass(frozen=True) class TypeSource(ChainedSource): def __post_init__(self): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 769fb510fdf1..181b8ee9b042 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -943,6 +943,7 @@ def handle_graph_break( self.output.add_output_instructions( [create_instruction("KW_NAMES", argval=kw_names)] ) + assert inst.arg is not None call_insts = create_call_function(inst.arg, False) call_insts[-1].copy_positions(inst) self.output.add_output_instructions(call_insts) @@ -3443,7 +3444,7 @@ def __init__( side_effects.store_cell(cell_var, contents_var) else: cell_var = side_effects.track_cell_new() - cell_var.local_name = name + cell_var.local_name = name # type: ignore[attr-defined] self.symbolic_locals[name] = cell_var # Populate `symbolic_locals` with cells captured by this frame, @@ -3461,7 +3462,7 @@ def __init__( cell_var = side_effects.track_cell_existing( cell_source, cell, contents_var ) - cell_var.local_name = name + cell_var.local_name = name # type: ignore[attr-defined] self.symbolic_locals[name] = cell_var self.symbolic_torch_function_state = SymbolicTorchFunctionState( diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py index c1a6fd03ba06..8709c5618d85 100644 --- a/torch/_dynamo/tensor_version_op.py +++ b/torch/_dynamo/tensor_version_op.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """This module implements tensor version operations for Dynamo tracing. It provides primitives for handling tensor versioning during tracing, particularly in the @@ -18,7 +16,11 @@ Note this is similar to how no_grad is handled. """ +from contextlib import AbstractContextManager +from typing import Any + import torch +from torch import SymInt from torch._prims import _make_prim, RETURN_TYPE from torch._subclasses import FakeTensorMode from torch._subclasses.functional_tensor import FunctionalTensorMode @@ -33,13 +35,14 @@ ) -@_tensor_version.py_impl(FakeTensorMode) -def _tensor_version_fake(fake_mode, self_tensor): +@_tensor_version.py_impl(FakeTensorMode) # type: ignore[misc] +def _tensor_version_fake(fake_mode: FakeTensorMode, self_tensor: Any) -> SymInt: """ The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the `._version` into an unbacked SymInt so that we don't need to specialize on the `._version` of input tensors to the graph. """ + assert fake_mode.shape_env is not None return fake_mode.shape_env.create_unbacked_symint() @@ -53,11 +56,15 @@ def _tensor_version_fake(fake_mode, self_tensor): torch.fx.node.has_side_effect(_unsafe_set_version_counter) -@_tensor_version.py_impl(FunctionalTensorMode) -def _tensor_version_functional(mode, self): +@_tensor_version.py_impl(FunctionalTensorMode) # type: ignore[misc] +def _tensor_version_functional(mode: FunctionalTensorMode, self: Any) -> int: return self._version -@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) -def _unsafe_set_version_counter_functional(ctx, tensors, versions): +@_unsafe_set_version_counter.py_impl(FunctionalTensorMode) # type: ignore[misc] +def _unsafe_set_version_counter_functional( + ctx: AbstractContextManager[Any], + tensors: tuple[torch.Tensor, ...], + versions: tuple[int, ...], +) -> None: torch._C._autograd._unsafe_set_version_counter(tensors, versions) diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index dc7a44684051..230aac4794f2 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """Testing utilities for Dynamo, providing a specialized TestCase class and test running functionality. This module extends PyTorch's testing framework with Dynamo-specific testing capabilities. @@ -18,7 +16,7 @@ import re import sys import unittest -from typing import Union +from typing import Any, Callable, Union import torch import torch.testing @@ -151,7 +149,12 @@ class CPythonTestCase(TestCase): fail = unittest.TestCase.fail failureException = unittest.TestCase.failureException - def compile_fn(self, fn, backend, nopython): + def compile_fn( + self, + fn: Callable[..., Any], + backend: Union[str, Callable[..., Any]], + nopython: bool, + ) -> Callable[..., Any]: # We want to compile only the test function, excluding any setup code # from unittest method = getattr(self, self._testMethodName) @@ -159,7 +162,7 @@ def compile_fn(self, fn, backend, nopython): setattr(self, self._testMethodName, method) return fn - def _dynamo_test_key(self): + def _dynamo_test_key(self) -> str: suffix = super()._dynamo_test_key() test_cls = self.__class__ test_file = inspect.getfile(test_cls).split(os.sep)[-1].split(".")[0] diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index 32d10b53da99..4e4135666d56 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """Common utilities for testing Dynamo's minifier functionality. This module provides the base infrastructure for running minification tests in Dynamo. @@ -25,7 +23,8 @@ import sys import tempfile import traceback -from typing import Optional +from collections.abc import Sequence +from typing import Any, Optional, Union from unittest.mock import patch import torch @@ -40,7 +39,7 @@ class MinifierTestResult: minifier_code: str repro_code: str - def _get_module(self, t): + def _get_module(self, t: str) -> str: match = re.search(r"class Repro\(torch\.nn\.Module\):\s+([ ].*\n| *\n)+", t) assert match is not None, "failed to find module" r = match.group(0) @@ -48,7 +47,7 @@ def _get_module(self, t): r = re.sub(r"\n{3,}", "\n\n", r) return r.strip() - def get_exported_program_path(self): + def get_exported_program_path(self) -> Optional[str]: # Extract the exported program file path from AOTI minifier's repro.py # Regular expression pattern to match the file path pattern = r'torch\.export\.load\(\s*["\'](.*?)["\']\s*\)' @@ -60,10 +59,10 @@ def get_exported_program_path(self): return file_path return None - def minifier_module(self): + def minifier_module(self) -> str: return self._get_module(self.minifier_code) - def repro_module(self): + def repro_module(self) -> str: return self._get_module(self.repro_code) @@ -71,7 +70,7 @@ class MinifierTestBase(torch._dynamo.test_case.TestCase): DEBUG_DIR = tempfile.mkdtemp() @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: super().setUpClass() if not os.path.exists(cls.DEBUG_DIR): cls.DEBUG_DIR = tempfile.mkdtemp() @@ -94,14 +93,14 @@ def setUpClass(cls): ) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: if os.getenv("PYTORCH_KEEP_TMPDIR", "0") != "1": shutil.rmtree(cls.DEBUG_DIR) else: print(f"test_minifier_common tmpdir kept at: {cls.DEBUG_DIR}") cls._exit_stack.close() # type: ignore[attr-defined] - def _gen_codegen_fn_patch_code(self, device, bug_type): + def _gen_codegen_fn_patch_code(self, device: str, bug_type: str) -> str: assert bug_type in ("compile_error", "runtime_error", "accuracy") return f"""\ {torch._dynamo.config.codegen_config()} @@ -109,7 +108,9 @@ def _gen_codegen_fn_patch_code(self, device, bug_type): torch._inductor.config.{"cpp" if device == "cpu" else "triton"}.inject_relu_bug_TESTING_ONLY = {bug_type!r} """ - def _maybe_subprocess_run(self, args, *, isolate, cwd=None): + def _maybe_subprocess_run( + self, args: Sequence[Any], *, isolate: bool, cwd: Optional[str] = None + ) -> subprocess.CompletedProcess[bytes]: if not isolate: assert len(args) >= 2, args assert args[0] == "python3", args @@ -174,7 +175,9 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): # Run `code` in a separate python process. # Returns the completed process state and the directory containing the # minifier launcher script, if `code` outputted it. - def _run_test_code(self, code, *, isolate): + def _run_test_code( + self, code: str, *, isolate: bool + ) -> tuple[subprocess.CompletedProcess[bytes], Union[str, Any]]: proc = self._maybe_subprocess_run( ["python3", "-c", code], isolate=isolate, cwd=self.DEBUG_DIR ) @@ -190,8 +193,13 @@ def _run_test_code(self, code, *, isolate): # Runs the minifier launcher script in `repro_dir` def _run_minifier_launcher( - self, repro_dir, isolate, *, minifier_args=(), repro_after=None - ): + self, + repro_dir: str, + isolate: bool, + *, + minifier_args: Sequence[Any] = (), + repro_after: Optional[str] = None, + ) -> tuple[subprocess.CompletedProcess[bytes], str]: self.assertIsNotNone(repro_dir) launch_file = _as_posix_path(os.path.join(repro_dir, "minifier_launcher.py")) with open(launch_file) as f: @@ -212,7 +220,9 @@ def _run_minifier_launcher( return launch_proc, launch_code # Runs the repro script in `repro_dir` - def _run_repro(self, repro_dir, *, isolate=True): + def _run_repro( + self, repro_dir: str, *, isolate: bool = True + ) -> tuple[subprocess.CompletedProcess[bytes], str]: self.assertIsNotNone(repro_dir) repro_file = _as_posix_path(os.path.join(repro_dir, "repro.py")) with open(repro_file) as f: @@ -230,7 +240,7 @@ def _run_repro(self, repro_dir, *, isolate=True): # `run_code` is the code to run for the test case. # `patch_code` is the code to be patched in every generated file; usually # just use this to turn on bugs via the config - def _gen_test_code(self, run_code, repro_after, repro_level): + def _gen_test_code(self, run_code: str, repro_after: str, repro_level: int) -> str: repro_after_line = "" if repro_after == "aot_inductor": repro_after_line = ( @@ -263,7 +273,13 @@ def _gen_test_code(self, run_code, repro_after, repro_level): # isolate=True only if the bug you're testing would otherwise # crash the process def _run_full_test( - self, run_code, repro_after, expected_error, *, isolate, minifier_args=() + self, + run_code: str, + repro_after: str, + expected_error: Optional[str], + *, + isolate: bool, + minifier_args: Sequence[Any] = (), ) -> Optional[MinifierTestResult]: if isolate: repro_level = 3 diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 88d67822b0b7..4ff88a25bce3 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - """ Tracing rules and policies for TorchDynamo compilation decisions. @@ -37,7 +35,6 @@ import sys import traceback import types -import typing import unittest from collections import defaultdict from pathlib import Path @@ -73,6 +70,7 @@ UserFunctionVariable, UserMethodVariable, ) +from .variables.base import VariableTracker np: Optional[types.ModuleType] = None @@ -82,10 +80,6 @@ pass -if typing.TYPE_CHECKING: - from .variables.base import VariableTracker - - """ A note on skip/inline rules: @@ -153,7 +147,14 @@ """ -manual_torch_name_rule_map: dict[str, Any] = { +manual_torch_name_rule_map: dict[ + str, + Union[ + type[TorchInGraphFunctionVariable], + type[SkipFunctionVariable], + type[UserFunctionVariable], + ], +] = { "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable, "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable, "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable, @@ -176,7 +177,6 @@ "torch.compiler.is_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable, "torch.compiler.is_exporting": TorchInGraphFunctionVariable, - "torch.autograd._profiler_enabled": SkipFunctionVariable, "torch._C._to_dlpack": SkipFunctionVariable, "torch.to_dlpack": SkipFunctionVariable, # We graph break on RNG state setters or getters like @@ -2434,6 +2434,7 @@ "torch.atleast_3d", "torch.autograd._calculate_shape", "torch.autograd._is_checkpoint_valid", + "torch.autograd._profiler_enabled", "torch.autograd._make_grads", "torch.autograd._register_py_tensor_class_for_device", "torch.autograd._tensor_or_tensors_to_tuple", @@ -2988,7 +2989,10 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: if ".py#" not in k: obj = load_object(k) else: - obj = _module_dir(torch) + k[len("torch/") :] + torch_dir = _module_dir(torch) + if torch_dir is None: + continue + obj = torch_dir + k[len("torch/") :] if obj is not None: if is_lru_cache_wrapped_function(obj): obj = obj.__wrapped__ @@ -3001,7 +3005,7 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]: return d -def _load_obj_from_str(fully_qualified_name): +def _load_obj_from_str(fully_qualified_name: str) -> Any: module, obj_name = fully_qualified_name.rsplit(".", maxsplit=1) return getattr(importlib.import_module(module), obj_name) @@ -3011,7 +3015,7 @@ def _load_obj_from_str(fully_qualified_name): """ -def load_object(name): +def load_object(name: str) -> Any: try: x = name.split("#") if len(x) == 2: @@ -3032,7 +3036,7 @@ def load_object(name): @functools.cache -def get_tensor_method(): +def get_tensor_method() -> frozenset[Any]: disallowed_tensor_methods = {"__new__", "_make_wrapper_subclass", "_make_subclass"} s = set() for name in dir(torch.Tensor): @@ -3061,7 +3065,7 @@ def get_tensor_method(): """ -def is_aten_op_or_tensor_method(obj): +def is_aten_op_or_tensor_method(obj: Any) -> bool: return obj in get_tensor_method() or isinstance( obj, (torch._ops.OpOverloadPacket, torch._ops.OpOverload), @@ -3097,16 +3101,16 @@ def __call__(self) -> set[int]: self.function_ids = value return self.function_ids - def get_name(self, idx: int, default: str): + def get_name(self, idx: int, default: str) -> str: self() # lazy init assert self.function_names is not None return self.function_names.get(idx, default) - def add(self, idx: int): + def add(self, idx: int) -> None: function_ids = self() # lazy init function_ids.add(idx) - def remove(self, idx: int): + def remove(self, idx: int) -> None: function_ids = self() if idx in function_ids: function_ids.remove(idx) @@ -3174,7 +3178,7 @@ def _numpy_function_ids() -> dict[int, str]: "sample", } - def is_supported(k, v, mod): + def is_supported(k: str, v: Any, mod: Any) -> bool: if not callable(v): return False if not getattr(v, "__module__", None): @@ -3233,53 +3237,53 @@ def _maybe_init_lazy_module(obj: object) -> None: fn() -def is_callable_allowed(obj) -> bool: +def is_callable_allowed(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _allowed_callable_ids -def is_nonstrict_trace_callable(obj) -> bool: +def is_nonstrict_trace_callable(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _nonstrict_trace_callable_ids -def is_callable_disallowed(obj) -> bool: +def is_callable_disallowed(obj: Any) -> bool: _maybe_init_lazy_module(obj) return id(obj) in _disallowed_callable_ids -def is_forbidden(obj) -> bool: +def is_forbidden(obj: Any) -> bool: _maybe_init_lazy_module(obj) return inspect.getattr_static(obj, "_dynamo_forbidden", False) -def is_builtin_callable(obj) -> bool: +def is_builtin_callable(obj: Any) -> bool: # See also torch/_dynamo/polyfills/loader.py, which removes items in _builtin_function_ids return id(obj) in _builtin_function_ids -def is_builtin_constant(obj) -> bool: +def is_builtin_constant(obj: Any) -> bool: return id(obj) in _builtin_constant_ids -def is_polyfilled_callable(obj) -> bool: +def is_polyfilled_callable(obj: Any) -> bool: # See also @torch._dynamo.decorators.substitute_in_graph(...), which adds items in _polyfilled_function_ids return id(obj) in _polyfilled_function_ids -def is_numpy(obj) -> bool: +def is_numpy(obj: Any) -> bool: if np is None: return False return isinstance(obj, (np.ndarray, np.generic)) or id(obj) in _numpy_function_ids -def is_numpy_dtype(obj) -> bool: +def is_numpy_dtype(obj: Any) -> bool: if np is None: return False return isinstance(obj, np.dtype) -def is_numpy_type_info(obj) -> bool: +def is_numpy_type_info(obj: Any) -> bool: if np is None: return False return isinstance(obj, (np.finfo, np.iinfo)) @@ -3317,7 +3321,7 @@ def is_numpy_type_info(obj) -> bool: ) -def _as_posix_path(path): +def _as_posix_path(path: str) -> str: posix_path = Path(os.path.normpath(path)).as_posix() # os.path.normpath and pathlib.Path remove trailing slash, so we need to add it back if path.endswith((os.path.sep, "/")): @@ -3325,13 +3329,13 @@ def _as_posix_path(path): return posix_path -def _strip_init_py(s): +def _strip_init_py(s: str) -> str: suffix = "__init__.py" s = s.removesuffix(suffix) return _as_posix_path(s) -def _module_dir(m: types.ModuleType): +def _module_dir(m: types.ModuleType) -> Optional[str]: # Protect against a module not exporting __file__ - this can happen for # frozen modules, for example. file = getattr(m, "__file__", None) @@ -3554,27 +3558,36 @@ def _module_dir(m: types.ModuleType): @functools.cache -def get_legacy_mod_inlinelist(): +def get_legacy_mod_inlinelist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() inlinelist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in LEGACY_MOD_INLINELIST } return inlinelist @functools.cache -def get_mod_inlinelist(): +def get_mod_inlinelist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() inlinelist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in MOD_INLINELIST } return inlinelist @functools.cache -def get_mod_skiplist(): +def get_mod_skiplist() -> set[str]: + torch_dir = _module_dir(torch) + if torch_dir is None: + return set() skiplist = { - _as_posix_path(_module_dir(torch) + m[len("torch.") :].replace(".", "/")) + _as_posix_path(torch_dir + m[len("torch.") :].replace(".", "/")) for m in MOD_SKIPLIST } return skiplist @@ -3631,14 +3644,14 @@ def get_mod_skiplist(): FORCE_SKIP_FILES = {f"{_module_dir(torch)}optim/lr_scheduler.py"} -def _recompile_re(): +def _recompile_re() -> None: global SKIP_DIRS_RE SKIP_DIRS_RE = re.compile( rf"^[^\s<]*({'|'.join(re.escape(_as_posix_path(d)) for d in SKIP_DIRS)})" ) -def add(import_name: str): +def add(import_name: str) -> None: if isinstance(import_name, types.ModuleType): return add(import_name.__name__) assert isinstance(import_name, str) @@ -3660,7 +3673,7 @@ class SkipResult: reason: Optional[str] -def check_file(filename, is_inlined_call=False): +def check_file(filename: Optional[str], is_inlined_call: bool = False) -> SkipResult: """Should skip this file?""" if filename is None: return SkipResult(True, "filename is None") @@ -3698,8 +3711,10 @@ def check_file(filename, is_inlined_call=False): ): return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + unittest_dir = _module_dir(unittest) if ( - filename.startswith(_module_dir(unittest)) + unittest_dir is not None + and filename.startswith(unittest_dir) and not torch._dynamo.config.enable_trace_unittest ): return SkipResult(True, "unittest") @@ -3754,7 +3769,7 @@ def f3(x, y): """ -def check_verbose(obj, is_inlined_call=False): +def check_verbose(obj: Any, is_inlined_call: bool = False) -> SkipResult: if isinstance( obj, ( @@ -3773,18 +3788,23 @@ def check_verbose(obj, is_inlined_call=False): elif isinstance(obj, types.CodeType): fi = FunctionInfo(None, obj.co_name, obj.co_filename, obj) elif isinstance(obj, (types.FunctionType, types.MethodType)): + filename = getfile(obj) + assert filename is not None fi = FunctionInfo( obj, obj.__name__, - getfile(obj), + filename, obj.__code__, # type: ignore[union-attr] # FIXME Add MethodType.__code__ to typeshed ) else: - fi = FunctionInfo(obj, None, getfile(obj), None) + filename = getfile(obj) + assert filename is not None + fi = FunctionInfo(obj, None, filename, None) # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) + assert rule is not None if issubclass( rule, ( @@ -3810,7 +3830,7 @@ def check_verbose(obj, is_inlined_call=False): ) -def check(obj, is_inlined_call=False): +def check(obj: Any, is_inlined_call: bool = False) -> bool: return check_verbose(obj, is_inlined_call).skipped @@ -3821,21 +3841,23 @@ def check(obj, is_inlined_call=False): _recompile_re() -def is_torch_inline_allowed(filename): +def is_torch_inline_allowed(filename: str) -> bool: return any(filename.startswith(d) for d in get_mod_inlinelist()) @functools.cache -def dynamo_dir(): +def dynamo_dir() -> Optional[str]: import torch._dynamo return _module_dir(torch._dynamo) -def is_torch(filename): - if filename.startswith(dynamo_dir()): +def is_torch(filename: str) -> bool: + dynamo_path = dynamo_dir() + if dynamo_path is not None and filename.startswith(dynamo_path): return False - return filename.startswith(_module_dir(torch)) + torch_path = _module_dir(torch) + return torch_path is not None and filename.startswith(torch_path) """ @@ -3843,7 +3865,7 @@ def is_torch(filename): """ -def lookup_callable(obj): +def lookup_callable(obj: Callable[..., Any]) -> Optional[type[VariableTracker]]: if not hashable(obj): return None # Custom allow/disallow in graph takes precedence over the general lookup. @@ -3864,18 +3886,18 @@ def lookup_callable(obj): """ -def lookup(obj): +def lookup(obj: Any) -> Optional[type[VariableTracker]]: return lookup_inner(obj) # also takes config.dont_skip_tracing into account def lookup_inner( - obj, - name=None, - filename=None, - is_direct_call=True, + obj: Any, + name: Optional[str] = None, + filename: Optional[str] = None, + is_direct_call: bool = True, reasons: Union[None, set[str]] = None, -): +) -> Optional[type[VariableTracker]]: result = _lookup_inner( obj, name=name, @@ -3890,12 +3912,15 @@ def lookup_inner( if config.dont_skip_tracing and result is SkipFunctionVariable: if filename is None: filename = getfile(obj) + assert filename is not None filename = _as_posix_path(filename) - dynamo_path = _as_posix_path(_module_dir(torch)) + "_dynamo" - if filename.startswith(dynamo_path) and not filename.endswith( - "test_dont_skip_tracing_functions.py" - ): - return SkipFunctionVariable + torch_dir = _module_dir(torch) + if torch_dir is not None: + dynamo_path = _as_posix_path(torch_dir) + "_dynamo" + if filename.startswith(dynamo_path) and not filename.endswith( + "test_dont_skip_tracing_functions.py" + ): + return SkipFunctionVariable if reasons is not None: reasons.add( "Attempted skip but we are ignoring skips due to torch._dynamo.config.dont_skip_tracing" @@ -3905,12 +3930,12 @@ def lookup_inner( def _lookup_inner( - obj, - name=None, - filename=None, - is_direct_call=True, - reasons: Union[None, set[str]] = None, -): + obj: Any, + name: Optional[str] = None, + filename: Optional[str] = None, + is_direct_call: bool = True, + reasons: Optional[set[str]] = None, +) -> Optional[type[VariableTracker]]: # Step 1: lookup obj's tracing rule in `torch_name_rule_map`. # The rules defined in `torch_name_rule_map` mainly includes two parts: # - Manually defined rules for any functions. @@ -3984,7 +4009,7 @@ def _lookup_inner( filename = getfile(obj) skip_result = check_file(filename, is_direct_call) - if reasons is not None: + if reasons is not None and skip_result.reason is not None: reasons.add(skip_result.reason) if skip_result.skipped: return SkipFunctionVariable @@ -3992,7 +4017,7 @@ def _lookup_inner( return UserFunctionVariable -def clear_lru_cache(): +def clear_lru_cache() -> None: torch._dynamo.trace_rules.get_torch_obj_rule_map.cache_clear() torch._dynamo.trace_rules.get_tensor_method.cache_clear() torch._dynamo.trace_rules.get_legacy_mod_inlinelist.cache_clear() diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0bd6b7f5e4a0..cf3ed5d135e6 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -102,6 +102,7 @@ Iterable, Iterator, KeysView, + Sequence, ValuesView, ) @@ -1713,9 +1714,15 @@ def get_event_data(self) -> dict[str, Any]: def __init__(self): self.tls = threading.local() + + from . import config + # Generate a unique id for this logger, which we can use in scuba to filter down # to a single python run. - self.id_ = str(uuid.uuid4()) + if config.pt2_compile_id_prefix: + self.id_ = f"{config.pt2_compile_id_prefix}-{uuid.uuid4()}" + else: + self.id_ = str(uuid.uuid4()) # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) @@ -2137,8 +2144,18 @@ def torch_clone(x): return result +@overload +def clone_inputs( + example_inputs: dict[str, Union[T, tuple[T, ...]]], +) -> dict[str, list[T]]: ... + + +@overload +def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ... + + def clone_inputs(example_inputs): - res: Union[dict[Any, Any], list[Any]] + res: Union[dict[str, Any], list[Any]] if type(example_inputs) is dict: res = dict(example_inputs) for key, value in res.items(): @@ -2219,7 +2236,7 @@ def torchscript(model, example_inputs, verbose=False): return None -def getfile(obj): +def getfile(obj: Any) -> Optional[str]: try: return inspect.getfile(obj) except (TypeError, OSError): @@ -2422,6 +2439,15 @@ def is_int_specialization_case(value, source): source.guard_source().is_specialized_nn_module() and not config.allow_unspec_int_on_nn_module ) + # integers coming from FSDP modules are considered static. This is + # purely empirical and perhaps we should have a better heuristic. + or ( + source.guard_source().is_fsdp_module() + and not ( + config.allow_unspec_int_on_nn_module + or config.allow_unspec_int_on_fsdp_module + ) + ) or ( source.guard_source().is_unspecialized_builtin_nn_module() and not config.allow_unspec_int_on_nn_module @@ -2572,6 +2598,10 @@ def tuple_iterator_getitem(it, index): return obj[start + index] +def dataclass_fields(cls): + return torch._dynamo.disable(dataclasses.fields)(cls) + + iter_next = next diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 8ae5a4bd6cee..0862d9da8311 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -36,7 +36,6 @@ import sys import traceback import types -import warnings import weakref from collections.abc import MutableMapping from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union @@ -52,6 +51,7 @@ set_feature_use, ) from torch._guards import TracingContext +from torch._higher_order_ops.flat_apply import flat_apply from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode @@ -132,6 +132,7 @@ get_locals_to_steal, get_static_address_type, is_frozen_dataclass, + is_function, is_function_or_wrapper, is_invoke_subgraph, is_lru_cache_wrapped_function, @@ -161,6 +162,7 @@ VariableTracker, VariableTrackerMeta, ) +from .builtin import BuiltinVariable from .constant import ConstantVariable, EnumVariable from .ctx_manager import ( AutocastModeVariable, @@ -304,8 +306,7 @@ def safe_has_grad(t): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): return hasattr(t, "grad") @@ -445,8 +446,18 @@ def __call__(self, value): if vt.source is None: vt.source = self.source + def _is_deduplicable_sym_variable(value, vt): + # Constants like 0, 1, 2, etc. can be unspecialized as SymNodeVariables sometimes, but we + # should NOT track them. If we use a single SymNodeVariable instance to track them + # across multiple uses, then guards created for one usage will incorrectly apply to + # all other usages of that constant, leading to unnecessary recompilations. + return is_torch_sym(value) and isinstance(vt, SymNodeVariable) + if ( - self._can_lift_attrs_to_inputs(vt) + ( + self._can_lift_attrs_to_inputs(vt) + or _is_deduplicable_sym_variable(value, vt) + ) and value not in self.tx.output.side_effects and not is_wrapper_or_member_descriptor(value) ): @@ -1214,6 +1225,12 @@ def build_key_value(i, k, v): ) and BuiltinMethodVariable.is_supported_builtin_method(value): self.install_guards(GuardBuilder.ID_MATCH) return BuiltinMethodVariable(value, source=self.source) + elif is_function(value) and value in (float.fromhex, float.hex): + self.install_guards(GuardBuilder.ID_MATCH) + return GetAttrVariable( + BuiltinVariable(float, source=self.source), + value.__name__, + ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) # For these wrappers, Dynamo points to the wrapped function, @@ -2994,6 +3011,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ): set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) + elif ( + isinstance(example_value, (int, float, bool)) + and proxy.node.target is flat_apply + ): + set_example_value(proxy.node, example_value) + return ConstantVariable.create(example_value, **options) elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 926904396887..137108f5fac3 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1271,6 +1271,19 @@ def call_method( args[1:], ) + if self.fn is float and len(args) == 1 and name in ("fromhex", "hex"): + if isinstance(args[0], ConstantVariable): + try: + fn = getattr(float, name) + res = fn(args[0].as_python_constant()) + return variables.ConstantVariable.create(res) + except (OverflowError, ValueError) as e: + raise_observed_exception( + type(e), + tx, + args=list(map(ConstantVariable.create, e.args)), + ) + if self.fn is object and name == "__init__": # object.__init__ is a no-op return variables.ConstantVariable(None) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 82dd2eb4caea..b874cfaadbc4 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -708,7 +708,7 @@ def speculate_subgraph( if restore_side_effects: new_side_effects = tx.output.side_effects.clone() - prev_side_effects.track_tensor_variables_from_runahead_side_effects( + prev_side_effects.track_runahead_tensor_and_symvar_side_effects( new_side_effects ) tx.output.side_effects = prev_side_effects @@ -991,7 +991,9 @@ def call_function( f"{operands.python_type()}", ) operands_seq = operands.unpack_var_sequence(tx) - if not only_consist_of(operands, (TensorVariable, ConstantVariable)): + if not only_consist_of( + operands, (TensorVariable, ConstantVariable, SymNodeVariable) + ): unimplemented( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." ) @@ -1219,26 +1221,46 @@ def call_function( additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) with discard_graph_changes(tx): - # See NOTE [unspecialize int carry with unbacked symints] # Note: this must be run under discard graph changes. - def create_unbacked_sym_node_var(tx) -> SymNodeVariable: - example_value = _create_unbacked_symint( - tx.output.fake_mode, ignore_fresh_unbacked_symbols=True - ) - proxy = tx.output.current_tracer.create_graph_input( - "unbacked_symint", type(example_value), example_value - ) - return SymNodeVariable.create(tx, proxy, example_value) - - new_operands_seq = [ - ( - create_unbacked_sym_node_var(tx) - if ( - isinstance(carry, ConstantVariable) - and carry.python_type() is int + def unspecialize_carried_inputs(tx, carry) -> VariableTracker: + # See NOTE [unspecialize int carry with unbacked symints] + if ( + isinstance(carry, ConstantVariable) and carry.python_type() is int + ) or isinstance(carry, SymNodeVariable): + example_value = _create_unbacked_symint( + tx.output.fake_mode, ignore_fresh_unbacked_symbols=True + ) + proxy = tx.output.current_tracer.create_graph_input( + "unbacked_symint", type(example_value), example_value ) - or (isinstance(carry, SymNodeVariable)) - else carry + return SymNodeVariable.create(tx, proxy, example_value) + else: + # See NOTE [unspecialize constant tensor carry] + assert isinstance(carry, TensorVariable) + cloned_carry = carry.clone() + cloned_carry.proxy.node.meta["example_value"].constant = None + return cloned_carry + + # clone inputs across subgraphs, to avoid unbacked memoization in fake prop + cond_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if isinstance(carry, TensorVariable) + else carry + ), + ) + for carry in operands_seq + ] + body_operands_seq = [ + unspecialize_carried_inputs( + tx, + ( + carry.call_method(tx, "clone", args=(), kwargs={}) + if isinstance(carry, TensorVariable) + else carry + ), ) for carry in operands_seq ] @@ -1251,7 +1273,7 @@ def create_unbacked_sym_node_var(tx) -> SymNodeVariable: ) = speculate_subgraph( tx, cond_fn, - new_operands_seq + additional_inputs_seq, + cond_operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, @@ -1316,7 +1338,7 @@ def create_unbacked_sym_node_var(tx) -> SymNodeVariable: ) = speculate_subgraph( tx, body_fn, - new_operands_seq + additional_inputs_seq, + body_operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 93547c79e956..3e0a91b5e922 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1045,10 +1045,10 @@ class NamedTupleVariable(TupleVariable): *TupleVariable._nonvar_fields, } - def __init__(self, items, tuple_cls, **kwargs) -> None: + def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None: super().__init__(items, **kwargs) self.tuple_cls = tuple_cls - self.dynamic_attributes = {} + self.dynamic_attributes = {} if not dynamic_attributes else dynamic_attributes def is_namedtuple(self): return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( @@ -1279,8 +1279,13 @@ def as_python_constant(self): raise NotImplementedError return iter([x.as_python_constant() for x in self.items]) + def has_unpack_var_sequence(self, tx): + return True + def unpack_var_sequence(self, tx): - return list(self.items[self.index :]) + r = list(self.items[self.index :]) + self.index = len(self.items) + return r def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: return self.unpack_var_sequence(tx) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 79471bda50ae..62d0542dcab0 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -316,17 +316,18 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): real_value = getattr(_input_associated_real_value, name) attr_source = AttrSource(self.source, name) - install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) # Typically we'd want to use variable builder here # but unfortunately id(real_value.__self__) is not id() if is_bound_tensor_method(real_value): + # No need to install the guard because its a bound tensor method from .misc import GetAttrVariable return GetAttrVariable( self, name, source=attr_source, py_type=type(real_value) ) + install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) return VariableTracker.build(tx, real_value, attr_source) def method_attr_ndim(self, tx): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c357e158503c..9a83acd61b1d 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -52,7 +52,7 @@ tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces -from ..exc import unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -142,6 +142,7 @@ torch.cuda.is_initialized, torch.xpu.current_device, torch.xpu.is_initialized, + torch.autograd._profiler_enabled, ] constant_fold_functions = [ @@ -1358,12 +1359,27 @@ def patched_fn(*args, **kwargs): # 2. Create a proxy call to `flat_apply`, then fake-tensor propagate # the call and wrap output into a VariableTracker. proxy = tx.output.create_proxy("call_function", flat_apply, all_args, {}) - out_vt = wrap_fx_proxy(tx, proxy) - # TODO support more output types - # Q: flat_apply will likely pytree_flatten the output for this, then - # how do we intercept the output before flatten, and wrap those? - # - Maybe we can have `flat_apply` return the output spec, so that - # Dynamo can unflatten and wrap the result. + try: + # TODO support more output types once `flat_apply` supports + # pytree-able output types. We can have Dynamo trace through an + # unflatten call (just like we traced through a flatten above) + # to rebuild the actual output VT. + out_vt = wrap_fx_proxy(tx, proxy) + except ( + # From `handle_traced_output`. + torch._dynamo.exc.Unsupported, + # From `flat_apply` assert on output type. + torch._dynamo.exc.TorchRuntimeError, + ): + unimplemented_v2( + gb_type="Unsupported output type for nonstrict_trace-ed function", + context=f"Function: {fn.__name__}", + explanation=( + "For `nonstrict_trace`-ed functions, only basic types (e.g., torch.Tensor, int, list)" + " are allowed as output. The result of this call contains an unsupported type." + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) return out_vt @@ -1378,12 +1394,19 @@ def patched_fn(*args, **kwargs): source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold - return ConstantVariable.create( - self.as_python_constant()( - *[x.as_python_constant() for x in args], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ), - ) + try: + return ConstantVariable.create( + self.as_python_constant()( + *[x.as_python_constant() for x in args], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ), + ) + except (OverflowError, TypeError, ValueError) as exc: + raise_observed_exception( + type(exc), + tx, + args=list(map(ConstantVariable.create, exc.args)), + ) if self.is_tensor_method(): name = self.value.__name__ diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 6c7e24ef16f0..c08f8099664f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -31,6 +31,7 @@ import enum import functools import inspect +import itertools import random import sys import threading @@ -58,6 +59,7 @@ from ..source import ( AttrSource, CallFunctionNoArgsSource, + DataclassFieldsSource, GetItemSource, RandomValueSource, TypeSource, @@ -624,11 +626,12 @@ def call_function( return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) + fields_source = DataclassFieldsSource(self.source) items = list(args) items.extend([None] * (len(fields) - len(items))) default_kwargs = {} - for field, var_tracker in zip(fields, items): + for ind, field, var_tracker in zip(itertools.count(), fields, items): if var_tracker is None: if field.name in kwargs: var_tracker = kwargs[field.name] @@ -637,7 +640,13 @@ def call_function( continue if field.default is not dataclasses.MISSING: - var_tracker = VariableTracker.build(tx, field.default) + var_tracker = VariableTracker.build( + tx, + field.default, + source=AttrSource( + GetItemSource(fields_source, ind), "default" + ), + ) elif field.default_factory is not dataclasses.MISSING: factory_fn = VariableTracker.build( tx, field.default_factory diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index c4f5809af0b4..50472c02375c 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<31664e4faa0eacd6f538ffed163078e190d9d2b98d762dd45b68eb1b7b12f0d1>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -50,6 +50,8 @@ enum ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, } diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 32c69140807b..933d30310b72 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 8) +SCHEMA_VERSION = (8, 9) TREESPEC_VERSION = 1 @@ -33,6 +33,8 @@ class ScalarType(IntEnum): UINT16 = 28 FLOAT8E4M3FN = 29 FLOAT8E5M2 = 30 + FLOAT8E4M3FNUZ = 31 + FLOAT8E5M2FNUZ = 32 class Layout(IntEnum): diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 19145e7f8e32..9167a6820ef4 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +# checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> AOTInductorModelPickleData: kind: struct fields: @@ -420,6 +420,8 @@ ScalarType: UINT16: 28 FLOAT8E4M3FN: 29 FLOAT8E5M2: 30 + FLOAT8E4M3FNUZ: 31 + FLOAT8E5M2FNUZ: 32 SchemaVersion: kind: struct fields: @@ -532,5 +534,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 8 +- 9 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 5c688b2a14d2..38ccbe287a87 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -143,6 +143,8 @@ def _reverse_map(d: dict[Any, Enum]): torch.bfloat16: ScalarType.BFLOAT16, torch.float8_e4m3fn: ScalarType.FLOAT8E4M3FN, torch.float8_e5m2: ScalarType.FLOAT8E5M2, + torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, + torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, } @@ -222,6 +224,31 @@ class _SerializedProgram: example_inputs: bytes +class LazyMap(dict): + """ + Dictionary class for deferred instantiation of node metadata values. + Purpose is to avoid creation of symbolic-shape tensors before relevant shape guards are parsed. + """ + + def __init__(self): + self.map = {} + self.evaluated = set() + + def __setitem__(self, k, v): + self.map[k] = v + + def __getitem__(self, k): + out = self.map[k] + if k in self.evaluated: + return out + self.evaluated.add(k) + self.map[k] = out() + return self.map[k] + + def __repr__(self): + return self.map.__repr__() + + def deserialize_device(d: Device) -> torch.device: if d.index is None: return torch.device(type=d.type) # type: ignore[call-overload] @@ -1669,7 +1696,7 @@ class Result: def __init__(self) -> None: self.serialized_name_to_node: dict[str, torch.fx.Node] = {} - self.serialized_name_to_meta: dict[str, MetaType] = {} + self.serialized_name_to_meta: LazyMap = LazyMap() # str -> MetaType self.graph = torch.fx.Graph() self.module = torch.nn.Module() @@ -1685,7 +1712,7 @@ def save_graph_module(self) -> Iterator[None]: self.graph = torch.fx.Graph() self.module = torch.nn.Module() self.serialized_name_to_node = {} - self.serialized_name_to_meta = {} + self.serialized_name_to_meta = LazyMap() self.unbacked_symbols: set[sympy.Symbol] = set() try: yield @@ -1874,32 +1901,32 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: # Handle the tensor metas. for name, tensor_value in serialized_graph.tensor_values.items(): log.debug("[deserialize_tensor_meta] %s (input): %s", name, tensor_value) - meta_val = self.deserialize_tensor_meta(tensor_value) - log.debug("[deserialize_tensor_meta] %s (output): %s", name, meta_val) - self.serialized_name_to_meta[name] = meta_val + self.serialized_name_to_meta[name] = ( + lambda v=tensor_value: self.deserialize_tensor_meta(v) + ) for name, sym_int_value in serialized_graph.sym_int_values.items(): log.debug("[deserialize_sym_int] %s (input): %s", name, sym_int_value) - int_val = self.deserialize_sym_int(sym_int_value) - log.debug("[deserialize_sym_int] %s (output): %s", name, int_val) - self.serialized_name_to_meta[name] = int_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_int_value: self.deserialize_sym_int(v) + ) for name, sym_float_value in serialized_graph.sym_float_values.items(): log.debug("[deserialize_sym_float] %s (input): %s", name, sym_float_value) - float_val = self.deserialize_sym_float(sym_float_value) - log.debug("[deserialize_sym_float] %s (output): %s", name, float_val) - self.serialized_name_to_meta[name] = float_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_float_value: self.deserialize_sym_float(v) + ) for name, sym_bool_value in serialized_graph.sym_bool_values.items(): log.debug("[deserialize_sym_bool] %s (input): %s", name, sym_bool_value) - bool_val = self.deserialize_sym_bool(sym_bool_value) - log.debug("[deserialize_sym_bool] %s (output): %s", name, bool_val) - self.serialized_name_to_meta[name] = bool_val + self.serialized_name_to_meta[name] = ( + lambda v=sym_bool_value: self.deserialize_sym_bool(v) + ) for name, script_obj_meta in serialized_graph.custom_obj_values.items(): log.debug("[deserialize_script_obj_meta] %s", script_obj_meta) - self.serialized_name_to_meta[name] = self.deserialize_script_obj_meta( - script_obj_meta + self.serialized_name_to_meta[name] = ( + lambda v=script_obj_meta: self.deserialize_script_obj_meta(v) ) log.debug("\n[deserialize graph nodes]") @@ -2078,13 +2105,25 @@ def _is_single_tensor_return(target) -> bool: fx_node.kwargs, fx_node.meta.get("val"), ) + + # handle ShapeEnv asserts + if target == torch.ops.aten._assert_scalar.default: + expr = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(expr, torch.SymBool): + self.shape_env.guard_or_defer_runtime_assert( + expr.node.expr, "", fx_node + ) + elif target == torch.ops.aten.sym_constrain_range_for_size.default: + sym = fx_node.args[0].meta["val"] # type: ignore[union-attr] + if isinstance(sym, torch.SymInt): + self.shape_env._constrain_range_for_size(sym.node.expr) + + # handle nn_module_stack; serialization throws away empty dicts if ( fx_node.op not in ["placeholder", "output"] and "nn_module_stack" not in fx_node.meta ): - fx_node.meta[ - "nn_module_stack" - ] = {} # serialization throws away empty dicts + fx_node.meta["nn_module_stack"] = {} def deserialize_input_spec(self, i: InputSpec) -> ep.InputSpec: log.debug("[deserialize_input_spec] %s", i) @@ -2261,8 +2300,6 @@ def deserialize( if symbol_name_to_range: for k, vr in symbol_name_to_range.items(): lower = vr.lower - if vr.upper >= 2: # max is >= 2, not sym bool range - lower = max(2, lower) self.symbol_name_to_range[k] = symbolic_shapes.ValueRanges( _int_to_sympy_int(lower, -int_oo), vr.upper ) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 1e2f84e5a3bd..5594d72aa2f2 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -8,6 +8,7 @@ import math import operator import re +from collections import defaultdict from collections.abc import Iterable from contextlib import contextmanager from inspect import ismethod, Parameter @@ -255,6 +256,8 @@ def _get_shape_env_from_gm(gm: torch.fx.GraphModule): def _rename_without_collisions( name_map: dict[str, str], + find_available: dict[str, int], + used_names: set[str], orig_name: str, name: str, is_placeholder: bool = False, @@ -262,23 +265,32 @@ def _rename_without_collisions( """ Renames nodes to avoid name collisions, with suffixing. name_map: map from original name to new name + find_available: map prefix to available suffix + used_names: cache of used names orig_name: mapping key name: candidate name (potentially suffixed, e.g. mul_2) is_placeholder: if the node is a placeholder, avoid detecting suffix """ - if name in name_map.values(): - # non-placeholder nodes may be suffixed with the count - # instead of adding another suffix, we will try to increment it - match = re.match(r"(.*)_(\d+)", name) - if match and not is_placeholder: - name, n = match.group(1), int(match.group(2)) - else: - n = 0 - while (dup_name := f"{name}_{n + 1}") in name_map.values(): - n += 1 - name_map[orig_name] = dup_name - else: - name_map[orig_name] = name + match = re.match(r"(.*)_(\d+)", name) + key = name + + if match and not is_placeholder: + prefix, n = match.group(1), match.group(2) + key = prefix + + new_name = name + if new_name in used_names: + new_name = f"{key}_{find_available[key] + 1}" + + match = re.match(r"(.*)_(\d+)", new_name) + if match: + prefix, n = match.group(1), match.group(2) + if int(n) > find_available[prefix]: + find_available[prefix] = int(n) + + name_map[orig_name] = new_name + used_names.add(new_name) + return name_map[orig_name] @@ -867,6 +879,15 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs): return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs} +def _build_cache(name, find_available, used_names): + used_names.add(name) + match = re.match(r"(.*)_(\d+)", name) + if match: + prefix, n = match.group(1), match.group(2) + if int(n) > find_available[prefix]: + find_available[prefix] = int(n) + + def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: """ Propagate placeholder names from the top-level graph into HigherOrderOp subgraphs, @@ -874,6 +895,7 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: Different HOO subgraph types have different input schemas, so we first enumerate them and gather the top-level named placeholder nodes. """ + # gather all HOO subgraphs and their top-level named placeholder nodes subgraph_ph_tuples: list[tuple[torch.fx.GraphModule, list[torch.fx.Node]]] = [] for node in gm.graph.nodes: @@ -897,12 +919,17 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: # propagate names for subgraph, hoo_phs in subgraph_ph_tuples: name_map: dict[str, str] = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() for i, node in enumerate(subgraph.graph.nodes): if i < len(hoo_phs): # placeholder, retain name name_map[node.name] = hoo_phs[i].name node.name = node.target = hoo_phs[i].name + _build_cache(node.name, find_available, used_names) else: # non-placeholder, check for collisions - node.name = _rename_without_collisions(name_map, node.name, node.name) + node.name = _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # recurse and recompile _name_hoo_subgraph_placeholders(subgraph) @@ -962,6 +989,8 @@ def _extract_pytree_key(x): raise RuntimeError(f"Pytree key of type {type(x)} not handled for {x}") name_map: dict[str, str] = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() # map user input names with mod.forward() signature combined_args = _bind_signature_to_inputs(mod, fake_args, fake_kwargs) @@ -978,6 +1007,8 @@ def _extract_pytree_key(x): if user_input_name: _rename_without_collisions( name_map, + find_available, + used_names, user_input_name, placeholder_prefixes[InputKind.USER_INPUT] + "_".join(_extract_pytree_key(x).lower() for x in arg_path), @@ -997,6 +1028,8 @@ def _extract_pytree_key(x): _rename_without_collisions( name_map, + find_available, + used_names, spec.arg.name, placeholder_prefixes[spec.kind] + base_name, is_placeholder=True, @@ -1015,7 +1048,9 @@ def _extract_pytree_key(x): for node in gm.graph.nodes: if node.op == "placeholder": continue - _rename_without_collisions(name_map, node.name, node.name) + _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # assign new node names for node in gm.graph.nodes: diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 954dc399f96b..e66ffefe0a00 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -384,6 +384,57 @@ def _reduce_tensor(self, tensor): return (_ident, (metadata,)) +@contextlib.contextmanager +def normalize_placeholder_names(gm: torch.fx.GraphModule): + """ + Context manager that normalizes the placeholder names in the graph module. + This is used while generating a cache key for AOTAutogradCache, so that two graphs + that are isomorphic when normalizing names can hit the same cache entry. + This is safe because nothing underneath AOTAutograd uses the node names on the + original dynamo graph: AOTAutograd re-traces with its own nodes, and guards are + in terms of original sources rather than placeholder names. + """ + # Standalone inductor: we're bypassing AOTAutogradCache anyway, so return the graph + # as-is + if not config.autograd_cache_normalize_inputs or not hasattr(gm, "graph"): + yield + return + + # Track all the old state of placeholders + old_placeholder_names = [] + old_used_names = copy(gm.graph._graph_namespace._used_names) + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + # _rename renames the node in the body of the function, + # but it doesn't change the raw name from node.target + # So we also set the raw_name of node.target to a new placeholder name + new_placeholder_name = f"p_{i}" + old_placeholder_names.append((n.name, n.target)) + n.target = new_placeholder_name + n._rename(new_placeholder_name) + i += 1 + gm.recompile() + try: + yield + finally: + # Used_names contains all our old placeholder names, + # so we clear it temporarily when we put them back + gm.graph._graph_namespace._used_names = set() + # Restore the placeholder names + i = 0 + for n in gm.graph.find_nodes(op="placeholder", sort=True): + if n.type != torch.SymInt: + (name, target) = old_placeholder_names[i] + n.target = target + n._rename(name) + i += 1 + assert i == len(old_placeholder_names) + # Now restore the old namespace's used names + gm.graph._graph_namespace._used_names = old_used_names + gm.recompile() + + def autograd_cache_key( gm: torch.fx.GraphModule, example_inputs, @@ -407,7 +458,6 @@ def autograd_cache_key( if triton.__version__ < "3.2.0": raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") - details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) pickler = AOTAutogradCachePickler(gm) # The prefix distinguishes among the other kinds of objects we cache @@ -924,21 +974,22 @@ def sanitize_gm_for_cache(gm: torch.fx.GraphModule): and then put them back before returning. This way, we generate a cache key based off of a canonical graph without these fields, and also guarantee they aren't used to affect the cache's output. """ - IGNORED_FIELDS = ( - "meta", # metadata used by export - "compile_subgraph_reason", # Used by dynamo only for logging, no change in inductor/autograd behavior - "_param_name_to_source", # Encapsulated by aot_config.aot_autograd_arg_pos_to_source - "_backend_id", - ) + # Mapping from each field to a default value + IGNORED_FIELDS: dict[str, Any] = { + "meta": {}, # metadata used by export + "compile_subgraph_reason": None, # Used by dynamo only for logging, no change in inductor/autograd behavior + "_param_name_to_source": None, # Encapsulated by aot_config.aot_autograd_arg_pos_to_source + "_backend_id": None, + } saved_fields = {} - for field in IGNORED_FIELDS: + for field, default_value in IGNORED_FIELDS.items(): saved_fields[field] = getattr(gm, field, None) # Clear the field - setattr(gm, field, None) + setattr(gm, field, default_value) try: - yield + with normalize_placeholder_names(gm): + yield finally: - # Put the fields back after dispatch_and_compile is complete for field, value in saved_fields.items(): setattr(gm, field, value) @@ -1029,8 +1080,7 @@ def clear(): pass @staticmethod - def load( - dispatch_and_compile: Callable, + def try_load( mod: Union[torch.fx.GraphModule, torch._dynamo.utils.GmWrapper], args, aot_config: AOTConfig, @@ -1038,7 +1088,7 @@ def load( boxed_forward_device_index: Optional[BoxedDeviceIndex], local: bool, remote: bool, - ) -> Callable: + ) -> Optional[Callable]: """ Load a result from the cache, and reconstruct a runtime wrapper around the object """ @@ -1147,7 +1197,6 @@ def load( time.time_ns(), forward_symints=symints, ) - compiled_fn = dispatch_and_compile() cache_info.update( { @@ -1181,6 +1230,7 @@ def load( }, payload_fn=lambda: json.dumps(cache_info), ) + return compiled_fn @classmethod diff --git a/torch/_functorch/_aot_autograd/frontend_utils.py b/torch/_functorch/_aot_autograd/frontend_utils.py new file mode 100644 index 000000000000..55b84c12df82 --- /dev/null +++ b/torch/_functorch/_aot_autograd/frontend_utils.py @@ -0,0 +1,284 @@ +# mypy: ignore-errors + +from collections.abc import KeysView +from contextlib import contextmanager +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._guards import detect_fake_mode +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from .. import config +from .schemas import AOTConfig, FakifiedFlatArgs + + +static_inputs_log = torch._logging.getArtifactLogger( + __name__, "cudagraph_static_inputs" +) + + +def process_inputs( + flat_args: list[Any], + aot_config: AOTConfig, + fake_mode: FakeTensorMode, + shape_env: Optional[ShapeEnv], + ignore_shape_env: bool = False, +) -> FakifiedFlatArgs: + with fake_mode: + + def convert(idx, x): + if shape_env is not None and not ignore_shape_env: + from torch._dynamo.source import ConstantSource + + if isinstance(x, int): + # We always specialize on scalar values in export. + if aot_config.is_export: + return x + source = ConstantSource(f"sym_{idx}") + return shape_env.create_symintnode( + shape_env.create_symbol(x, source), hint=x, source=source + ) + if isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + fake_mode, x + ) + if not isinstance(x, torch.Tensor): + return x + if isinstance(x, FakeTensor): + assert x.fake_mode is fake_mode + return x + if is_traceable_wrapper_subclass(x): + attrs, _ = x.__tensor_flatten__() + if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): + assert all( + getattr(x, attr).fake_mode is fake_mode for attr in attrs + ) + return x + + # see note [Tensor Fakification and Symbol Caching] + symbolic_context = None + source = None + trace = True + if tracing_context := torch._guards.TracingContext.try_get(): + if x in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[x] + source = symbolic_context.tensor_source + # We already fakeified this tensor in Dynamo, don't + # dump the trace for it again + trace = False + if ( + idx < aot_config.num_params_buffers + and config.static_weight_shapes + and not symbolic_context + ): + # TODO: Ensure that this codepath is never exercised from + # Dynamo + return fake_mode.from_tensor(x, static_shapes=True) + + result = fake_mode.from_tensor( + x, + static_shapes=ignore_shape_env, + symbolic_context=symbolic_context, + source=source, + trace=trace, + ) + return result + + return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + + +def construct_fake_mode( + flat_args: list[Any], aot_config: AOTConfig +) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: + fake_mode = detect_fake_mode(flat_args) + if fake_mode is None: + shape_env = ShapeEnv() if aot_config.dynamic_shapes else None + fake_mode = FakeTensorMode(shape_env=shape_env) + else: + shape_env = fake_mode.shape_env + return (fake_mode, shape_env) + + +def _try_get_metadata_from_dynamo( + mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int +) -> tuple[Optional[list[torch._guards.Source]], list[int]]: + """ + Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. + We first verify that `mod` does come from Dynamo, then we handle cases where + metadata might be missing. + + Returns: + aot_autograd_arg_pos_to_source: used to dedup params and their guards + static_input_indices: used to identify static inputs for cudagraphs + """ + # Note [Assumption on Dynamo Metadata] + # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, + # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. + # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to + # be propagated in order to be recognized as a dynamo graph + + if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): + # graph was not captured by dynamo + return None, [] + + if not hasattr(mod, "_param_name_to_source"): + # is from export + return None, [] + + # We now know this came from dynamo, and (1) we care about guards, + # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards + # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. + # Additionally, we mark static indices for cudagraphs. + param_name_to_source = mod._param_name_to_source + seen_sources = set() + + aot_autograd_arg_pos_to_source = [] + static_input_indices = [] + # Collect the new inputs lifted by aotdispatch + for i, name in enumerate(param_keys): + assert name in param_name_to_source, f"{name} not found." + source = param_name_to_source[name] + assert source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + + static_input_indices.append(i) + + # Collect the dynamo graph inputs + # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID + # matched tensors back into the Fx graph, this might not be necessary. + for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): + assert hasattr(node, "_dynamo_source") + source = node._dynamo_source + # `source`` specifies the source from user code. ddp optimizer may have + # intermediate values becoming submodule placeholders which does not + # have a source + assert source is None or source not in seen_sources, source + seen_sources.add(source) + aot_autograd_arg_pos_to_source.append(source) + source_name = source.name() if source else str(source) + + # input[i] in dynamo is now: + # input[i + len(extra_params)] in AOT, + # where extra_params are the params/buffers that dynamo baked into the + # OutputGraph + actual_pos = pos + len(param_keys) + + if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( + "_dynamo_static_input_type", None + ): + static_inputs_log.debug( + "Adding static input pos %s for source %s", actual_pos, source_name + ) + static_input_indices.append(actual_pos) + else: + static_inputs_log.debug( + "Non-static input pos %s for source %s", actual_pos, source_name + ) + + assert full_args_num == len(aot_autograd_arg_pos_to_source) + return aot_autograd_arg_pos_to_source, static_input_indices + + +@contextmanager +def _detect_attribute_assignment(mod: torch.nn.Module): + # Do not allow assignment of tensor attributes during export unless + # the attribute is registered as a buffer. + + NN_MODULE_STD_ATTRS = [ + "_backward_hooks", + "_backward_pre_hooks", + "_buffers", + "_forward_hooks", + "_forward_hooks_always_called", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_is_full_backward_hook", + "_load_state_dict_post_hooks", + "_load_state_dict_pre_hooks", + "_modules", + "_non_persistent_buffers_set", + "_parameters", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, + } + + def _get_attributes(mod): + # return any attributes of a module that are not standard attributes + return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} + + # save state of attributes before enter + snapshot = pytree.tree_map( + lambda x: x, + _get_attributes(mod), + is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, + ) + try: + yield + finally: + # after exit, compare state of attributes with snapshot + # to detect which tensor attributes were assigned + assigned_tensor_attributes = [] + + def _collect_assigned_tensor_attributes(kp, v, _v): + if _v is not v: + attr, *rest = kp + if isinstance(v, torch.Tensor): + assigned_tensor_attributes.append( + f"self.{attr.key}{pytree.keystr(rest)}" + ) + # TODO(avik): Assigning all other types are allowed right now. + # Maybe in the future we want to limit this to primitive types? + return v + + new_attrs = _get_attributes(mod) + if len(new_attrs) != len(snapshot): + added_attrs = new_attrs.keys() - snapshot.keys() + deleted_attrs = snapshot.keys() - new_attrs.keys() + + if len(added_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were created in the model.forward: {added_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + if len(deleted_attrs) > 0: + raise ValueError( + f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " + f"Such attributes must be registered as buffers using the `register_buffer` " + f"API and must be initialized at model.__init__ " + f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) + + pytree.tree_map_with_path( + _collect_assigned_tensor_attributes, snapshot, new_attrs + ) + # restore state of all attributes (including, e.g., of primitive types) + mod.__dict__.update(snapshot) + + if assigned_tensor_attributes: + if len(assigned_tensor_attributes) > 1: + noun, verb = "attributes", "were" + else: + noun, verb = "attribute", "was" + raise ValueError( + f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " + "Such attributes must be registered as buffers using the `register_buffer` API " + "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." + ) diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/graph_capture.py similarity index 99% rename from torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py rename to torch/_functorch/_aot_autograd/graph_capture.py index be3226ca01f5..f4710bc8000c 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -23,8 +23,7 @@ assert_functional_graph, propagate_input_mutation_stacktraces, ) -from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta -from .traced_function_transforms import ( +from .graph_capture_wrappers import ( aot_dispatch_subclass, create_functionalized_fn, create_joint, @@ -32,6 +31,7 @@ fn_prepped_for_autograd, handle_effect_tokens_fn, ) +from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta from .utils import ( copy_fwd_metadata_to_bw_nodes, register_buffer_assignment_hook, diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py similarity index 100% rename from torch/_functorch/_aot_autograd/traced_function_transforms.py rename to torch/_functorch/_aot_autograd/graph_capture_wrappers.py diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/graph_compile.py similarity index 95% rename from torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py rename to torch/_functorch/_aot_autograd/graph_compile.py index 53bfa1e3c51e..cc64c82c2920 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -18,13 +18,18 @@ import traceback from collections import defaultdict from contextlib import nullcontext -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import detect_fake_mode, dynamo_timed, lazy_format_graph_code +from torch._dynamo.utils import ( + CompileEventLogger, + detect_fake_mode, + dynamo_timed, + lazy_format_graph_code, +) from torch._guards import CompileContext, TracingContext from torch._logging import getArtifactLogger, trace_structured from torch._subclasses import FakeTensor @@ -46,10 +51,7 @@ should_bundle_autograd_cache, should_use_remote_autograd_cache, ) -from .dispatch_and_compile_graph import ( - aot_dispatch_autograd_graph, - aot_dispatch_base_graph, -) +from .graph_capture import aot_dispatch_autograd_graph, aot_dispatch_base_graph from .logging_utils import track_graph_compiling from .runtime_wrappers import ( AOTDedupeWrapper, @@ -67,7 +69,13 @@ pre_compile, RuntimeWrapper, ) -from .schemas import AOTConfig, MutationType, ViewAndMutationMeta +from .schemas import ( + AOTConfig, + AOTGraphCapture, + AOTState, + MutationType, + ViewAndMutationMeta, +) from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta from .utils import ( _get_symint_hints, @@ -92,6 +100,7 @@ # Returns a Callable and a ViewAndMutationMeta. # Currently, only export needs the ViewAndMutationMeta after this function. +# TODO: Refactor this DispatchReturn = tuple[Callable, ViewAndMutationMeta] @@ -102,46 +111,68 @@ def _create_wrappers_for_dispatch(needs_autograd: bool) -> list[CompilerWrapper] return [AOTDedupeWrapper(), AOTSyntheticBaseWrapper(trace_joint=needs_autograd)] -# Export's dispatching logic is unique in a few ways: it only needs the "graph" -# bits of aot_autograd, and doesn't need to do any specific wrapping. -def aot_dispatch_export( +def aot_stage1_graph_capture( + aot_state: AOTState, flat_fn: Callable, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, - needs_autograd: bool, -) -> DispatchReturn: - wrappers = _create_wrappers_for_dispatch(needs_autograd) - flat_fn, flat_args, fw_metadata = pre_compile( +) -> AOTGraphCapture: + aot_config = aot_state.aot_config + + wrappers = _create_wrappers_for_dispatch(aot_state.needs_autograd) + flat_fn, aot_state.flat_args, aot_state.fw_metadata = pre_compile( wrappers, flat_fn, - flat_args, + aot_state.flat_args, aot_config, - fw_metadata=fw_metadata, + fw_metadata=aot_state.fw_metadata, ) - if needs_autograd and not aot_config.pre_dispatch: - graph, _, _ = aot_dispatch_autograd_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + # NB: This is currently only used for backwards, where fwd/bwd + # deterministic TLS can be different + aot_state.fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + if aot_state.needs_autograd and not aot_config.pre_dispatch: + # FYI: this being moved to trigger in export is new, seems fine! + with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): + graph, updated_flat_args, maybe_subclass_meta = aot_dispatch_autograd_graph( + flat_fn, + aot_state.flat_args, + aot_config, + fw_metadata=aot_state.fw_metadata, + ) else: - graph, _, _ = aot_dispatch_base_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata + graph, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( + flat_fn, aot_state.flat_args, aot_config, fw_metadata=aot_state.fw_metadata ) + return AOTGraphCapture( + wrappers=wrappers, + graph=graph, + updated_flat_args=updated_flat_args, + maybe_subclass_meta=maybe_subclass_meta, + ) + + +def aot_stage2_export( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture +) -> DispatchReturn: + graph = aot_graph_capture.graph + aot_config = aot_state.aot_config + wrappers = aot_graph_capture.wrappers + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="export") + # NB: the wrappers that run in pre_compile for export are # either a no-op, because they're not needed, or will raise a runtime error, # since they don't support export. # We still run these wrappers to make sure that they're not needed pre compile, # but we technically don't need to run them post compile at all here. - compiled_fn, fw_metadata = post_compile( - wrappers, graph, aot_config, runtime_metadata=fw_metadata + compiled_fn, aot_state.fw_metadata = post_compile( + wrappers, graph, aot_config, runtime_metadata=aot_state.fw_metadata ) # Therefore, since no wrapperes run, we don't get back a callable - we get back the raw fx graph # (either a joint or an inference-only graph) assert isinstance(compiled_fn, torch.fx.GraphModule) - return compiled_fn, fw_metadata + return compiled_fn, aot_state.fw_metadata def sanitize_aot_config(input: AOTConfig) -> AOTConfig: @@ -166,23 +197,33 @@ def sanitize_aot_config(input: AOTConfig) -> AOTConfig: ) -def aot_dispatch_base( - flat_fn, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, +def aot_stage2_compile( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, +) -> DispatchReturn: + if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch: + return aot_stage2_autograd(aot_state, aot_graph_capture) + else: + return aot_stage2_inference(aot_state, aot_graph_capture) + + +def aot_stage2_inference( + aot_state: AOTState, + aot_graph_capture: AOTGraphCapture, ) -> DispatchReturn: """ Handles functions that don't need autograd. Runs wrappers and compiles with fw_compiler. """ - wrappers = _create_wrappers_for_dispatch(needs_autograd=False) - flat_fn, flat_args, fw_metadata = pre_compile( - wrappers, flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) - fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + fw_module = aot_graph_capture.graph + wrappers = aot_graph_capture.wrappers + updated_flat_args = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="inference") + # Save the forward_graph_str right after aot_dispatch_base_graph, # to save in the cache aot_forward_graph_str = None @@ -195,19 +236,11 @@ def aot_dispatch_base( ) fakified_out_wrapper = FakifiedOutWrapper() - ( - fw_module, - updated_flat_args, - fw_metadata, - ) = fakified_out_wrapper.pre_compile( + fakified_out_wrapper.pre_compile( fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata ) functionalized_rng_wrapper = FunctionalizedRngRuntimeWrapper() - ( - fw_module, - updated_flat_args, - fw_metadata, - ) = functionalized_rng_wrapper.pre_compile( + functionalized_rng_wrapper.pre_compile( fw_module, updated_flat_args, aot_config, fw_metadata=fw_metadata ) assert isinstance(fw_module, GraphModule) @@ -852,7 +885,7 @@ def _wrapper(*args): def prepare_hook_gm(aot_config, fn, args): - from torch._functorch._aot_autograd.dispatch_and_compile_graph import _create_graph + from torch._functorch._aot_autograd.graph_capture import _create_graph fn, args = create_wrap_fn(fn, args) gm = _create_graph(fn, args, aot_config=aot_config) @@ -1247,31 +1280,23 @@ def _log_structured_logs(): bw_module.recompile() -def aot_dispatch_autograd( - flat_fn, - flat_args: list[Any], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, +def aot_stage2_autograd( + aot_state: AOTState, aot_graph_capture: AOTGraphCapture ) -> DispatchReturn: """ Autograd logic. Generates a joint graph, partitions it, manipulates the input with various wrappers, and returns a wrapped torch.autograd.Function with a forward and backward. """ - wrappers = _create_wrappers_for_dispatch(needs_autograd=True) - flat_fn, flat_args, fw_metadata = pre_compile( - wrappers, - flat_fn, - flat_args, - aot_config, - fw_metadata=fw_metadata, - ) - fw_metadata.deterministic = torch.are_deterministic_algorithms_enabled() - with dynamo_timed("aot_trace_joint_graph", log_pt2_compile_event=True): - fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph( - flat_fn, flat_args, aot_config, fw_metadata=fw_metadata - ) + wrappers = aot_graph_capture.wrappers + fx_g = aot_graph_capture.graph + flat_args = aot_state.flat_args + joint_inputs = aot_graph_capture.updated_flat_args + maybe_subclass_meta = aot_graph_capture.maybe_subclass_meta + aot_config = aot_state.aot_config + fw_metadata = aot_state.fw_metadata + + CompileEventLogger.try_add_pt2_compile("backend_compile", dispatch_mode="autograd") # Copied from aot_dispatch_autograd_graph. disable_amp = torch._C._is_any_autocast_enabled() @@ -1579,11 +1604,7 @@ def aot_dispatch_autograd( adjusted_flat_args = joint_inputs[0] fakified_out_wrapper = FakifiedOutWrapper() - ( - fw_module, - adjusted_flat_args, - fw_metadata, - ) = fakified_out_wrapper.pre_compile( + fakified_out_wrapper.pre_compile( fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata ) @@ -1600,11 +1621,7 @@ def aot_dispatch_autograd( ] adjusted_flat_args.extend(rng_states) # type: ignore[arg-type] - ( - fw_module, - adjusted_flat_args, - fw_metadata, - ) = functionalized_rng_wrapper.pre_compile( + functionalized_rng_wrapper.pre_compile( fw_module, adjusted_flat_args, aot_config, fw_metadata=fw_metadata ) if tracing_context := torch._guards.TracingContext.try_get(): diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 77eebd5e6248..805bb5d79c8a 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -41,6 +41,7 @@ from .. import config from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata from .functional_utils import gen_alias_from_base +from .graph_capture_wrappers import aot_dispatch_subclass from .input_output_analysis import ( compute_overlapping_inputs, create_synthetic_base_metadata, @@ -49,6 +50,8 @@ from .logging_utils import describe_input, format_guard_bug_msg, track_graph_compiling from .schemas import ( AOTConfig, + CompilerWrapper, + InductorWrapper, InputAliasInfo, MemoryFormatMeta, MutationType, @@ -64,7 +67,6 @@ runtime_unwrap_tensor_subclasses, wrap_tensor_subclasses, ) -from .traced_function_transforms import aot_dispatch_subclass from .utils import ( call_func_at_runtime_with_args, make_boxed_func, @@ -80,54 +82,6 @@ zip = strict_zip -class CompilerWrapper: - """ - A wrapper around the inputs and outputs to the compiler_fn. We separate these into two parts: - - 1. The prologue, which edits the input to the compiler_fn(flat_fn, flat_args, etc) - 2. The epilogue, which edits the outputs of the compiler_fn (compiled_fn, real arguments) - - Each wrapper below should be implemented as a CompilerWrapper, so that we can facilitate - caching on the compiled output, and re-wrapping the output via epilogues. - Extra metadata that is needed to compute pre or post compile can be passed in via attributes. - """ - - def pre_compile( - self, - flat_fn, - flat_args: list[Tensor], - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: - """ - Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. - Args: - flat_fn: The function to compile - flat_args: Metadata from example inputs of the function to compile - aot_config: AOTConfig passed in at compile time - fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args - """ - return flat_fn, flat_args, fw_metadata - - def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: - """ - Given an output of the compiler, wrap it with information received from prologue. - Args: - compiled_fn: Callable after calling compiler_fn - aot_config: AOTConfig after calling prologue - runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. - Example: - - def wrapped_compiled_fn(args): - # do something with args, aot_config, fw_metadata - return compiled_fn(args) - - return wrapped_compiled_fn - """ - return compiled_fn - - # The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic # that needs to run after the compiled function. # @@ -505,7 +459,7 @@ def _runtime_wrapper(*args, **kwargs): @dataclass -class FunctionalizedRngRuntimeWrapper(CompilerWrapper): +class FunctionalizedRngRuntimeWrapper(InductorWrapper): # TODO: I would love to get rid of this argument, but it's # Wrapped pretty tightly around our aot_dispatch_autograd logic. # Specifically, tensors_saved_for_backwards_slice's value is both used for calculating indices @@ -517,12 +471,12 @@ class FunctionalizedRngRuntimeWrapper(CompilerWrapper): def pre_compile( self, - flat_fn, + flat_fn: torch.fx.GraphModule, flat_args, aot_config, *, fw_metadata, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + ) -> None: if config.functionalize_rng_ops: # Update example inputs for the fw_compiler fake_mode = detect_fake_mode() @@ -531,7 +485,6 @@ def pre_compile( # We are not clearing flat_args here because # 1) There is a check in the debug compiler at the end # 2) It does not matter as these are fake tensors - return flat_fn, flat_args, fw_metadata def post_compile( self, @@ -580,7 +533,7 @@ def _functionalized_rng_runtime_epilogue( @dataclass -class FakifiedOutWrapper(CompilerWrapper): +class FakifiedOutWrapper(InductorWrapper): out_metas: list[torch.Tensor] = field(default_factory=list) # TracingContext.fwd_output_strides # Generated from actually doing compile @@ -595,7 +548,7 @@ def pre_compile( aot_config, *, fw_metadata, - ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + ) -> None: tracing_context = torch._guards.TracingContext.try_get() if tracing_context and tracing_context.fakify_first_call: self.out_metas = [ @@ -603,7 +556,6 @@ def pre_compile( ] else: self.needs_post_compile = False - return fw_module, flat_args, fw_metadata def _compute_output_meta_with_inductor_strides(self): out = self.out_metas diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 9b3239823303..efb16234c20c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -4,19 +4,28 @@ input/output types, metadata, config, function signatures etc. """ +from __future__ import annotations + import collections import dataclasses import functools import itertools -from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, NewType, Optional, Union +from typing import ( + Any, + Callable, + NewType, + Optional, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) import torch import torch.utils._pytree as pytree -from torch._guards import Source -from torch._ops import OpOverload +from torch import Tensor from torch._subclasses import FakeTensor from torch._subclasses.fake_tensor import is_fake from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -29,6 +38,16 @@ from .utils import strict_zip +if TYPE_CHECKING: + import contextlib + from collections.abc import Iterable, Sequence + + from torch._guards import Source + from torch._inductor.output_code import OutputCode + from torch._inductor.utils import InputType + from torch._ops import OpOverload + + zip = strict_zip @@ -166,7 +185,7 @@ class MemoryFormatMeta: memory_format: Optional[torch.memory_format] = None @staticmethod - def from_tensor(t: torch.Tensor) -> Optional["MemoryFormatMeta"]: + def from_tensor(t: torch.Tensor) -> Optional[MemoryFormatMeta]: # We only memorize expected memory format for # 1. Traceable wrapper subclasses # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. @@ -232,7 +251,7 @@ class SubclassCreationMeta: # meta and attrs are produced by the subclass's __tensor_flatten__. # We need to keep them around along with outer_size / outer_stride to plumb them # into __tensor_unflatten__ - attrs: dict[str, Union["SubclassCreationMeta", PlainTensorMeta]] + attrs: dict[str, Union[SubclassCreationMeta, PlainTensorMeta]] outer_size: Iterable[Union[None, int, torch.SymInt]] outer_stride: Iterable[Union[None, int, torch.SymInt]] meta: Any @@ -828,7 +847,7 @@ def from_tracing_metadata( num_user_outputs: int, loss_index: Optional[int], backward_signature: Optional[BackwardSignature], - ) -> "GraphSignature": + ) -> GraphSignature: graph_inputs = graph_input_names graph_outputs = graph_output_names parameters = list(named_parameters) @@ -966,3 +985,244 @@ def __post_init__(self): "SubclassTracingInfo", ["plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"], ) + + +@dataclass +class AOTState: + """ + When we run AOTAutograd, this class encapsulates the state in the compiler which + must be preserved across stages. This is state in the traditional sense (not an + environment) because some values in this structure change as we progress through + pipelines in AOTAutograd. + """ + + # Whether or not we need to handle autograd when doing graph capture and + # compilation. Although the calling convention for non-autograd graph + # capture in AOTAutograd is simple and can be relied upon, the autograph + # capture calling convention is quite complicated and in general you are + # only expected to pass to aot_stage2_compile to process. + needs_autograd: bool + + # The FAKE flat arguments which we will do tracing with. Although you + # might naively expect this to be immutable, it's not: when we perform + # tracing, we may execute code that modifies the metadata of inputs, + # causing the args to become "invalid". It's also nontrivial to have a + # "golden" set of fake values and deepcopy them just in time when you + # might destructively mutate them (Voz and I tried very hard to do this). + # So we just periodically renew this field. Don't worry too much about + # this unless you're specifically trying to track down an input metadata + # mutation bug. + # + # (By the way, this is NEVER the joint inputs! Those only ever go in + # AOTGraphCapture) + flat_args: list[Any] + + # This contains view and mutation information about the function, which we + # detected by doing an initial trace when we created this state. + fw_metadata: ViewAndMutationMeta + + # Top-level configuration + # This is morally immutable but sometimes we are naughty and mutate it. + aot_config: AOTConfig + + # When performing AOTAutograd traces and other passes, we typically + # require a lot of active context managers; most typically these either + # (1) ensure we are faithfully replicating the original PyTorch context + # managers or (2) toggle some behaviors in PyTorch to make it more + # suitable for tracing. When you use AOTState, you're expected to have + # created an ExitStack, entered it; then while we are running AOTAutograd + # we will add things onto the stack as necessary. When you're all done + # with processing AOTAutograd, you can exit this stack. All functions + # that take AOTState expect the ExitStack to not have been exited yet. + # + # TODO: We potentially could offer a resumable context manager, where you + # can cancel it and reenable it later when you need it. + stack: contextlib.ExitStack + + +class CompilerWrapper: + """ + AOTAutograd needs to do many transformations to the calling convention of the user function + it is tracing, e.g., deduplicating inputs, unpacking subclasses, etc. CompilerWrapper lets + us factor these into compositional stages so we can handle each transformation incrementally + instead of having to do it all at once. + + Since there is a calling convention change, there are two parts to the wrpaper: + + 1. The prologue, which is about compile-time behavior: given this original function, what + is the new function with modified calling convention that we should trace with AOTAutograd + to get the FX graph we will do joint passes, partitioning and ultimate Inductor compilation on? + We get (flat_fn, flat_args), the original function under trace and inputs we were + going to feed it, and produce a new function and new inputs to feed it. + + 2. The epilogue, which is about run-time behavior: we have now compiled the modified calling + convention function, we need to wrap it so that we have a new function that has the + original calling convention of the original function, so that our users can call it + at the old signature they expected. We get (compiled_fn, real arguments), the newly + compiled function we need to wrap. + + Note about caching: we do NOT directly serialize the runtime wrappers; instead, they + are reapplied to compiled_fn after we have finished deserializing the compiled_fn. + + Extra metadata that is needed to compute pre or post compile can be passed in via attributes. + """ + + def pre_compile( + self, + flat_fn, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> tuple[Callable, list[Tensor], ViewAndMutationMeta]: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return flat_fn, flat_args, fw_metadata + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +class InductorWrapper: + """ + This is sort of like CompilerWrapper, but it happens at a different part of the lifecycle: + it talks about transformations we do to the traced and partitioned FX graph before we + send it to the Inductor compiler. + + Once again, there are two parts: + + 1. The prologue, which "modifies" the FX graph before we send it to + Inductor. I say "modifies" because... we don't really actually do + anything nontrivial in either of our two implementations. + 2. The epilogue, which modifies the compiled function produced by Inductor + + Although hypothetically these wrappers could be used compositionally in a centralized + wrappers list, in practice they seem to just be invoked manually when needed. + + NB: The flat_args input is sometimes mutated. This is probably naughty but whatever. + """ + + def pre_compile( + self, + fw_module: torch.fx.GraphModule, + flat_args: list[Tensor], + aot_config: AOTConfig, + *, + fw_metadata: ViewAndMutationMeta, + ) -> None: + """ + Process the inputs to the compiler_fn. You can pass in extra metadata via kwargs. + Args: + flat_fn: The function to compile + flat_args: Metadata from example inputs of the function to compile + aot_config: AOTConfig passed in at compile time + fw_metadata: ViewAndMutationMeta generated from flat_fn and flat_args + """ + return + + def post_compile(self, compiled_fn, aot_config, *, runtime_metadata) -> Callable: + """ + Given an output of the compiler, wrap it with information received from prologue. + Args: + compiled_fn: Callable after calling compiler_fn + aot_config: AOTConfig after calling prologue + runtime_metadata: ViewAndMutationMeta after calling all wrappers's pre_compile steps. + Example: + + def wrapped_compiled_fn(args): + # do something with args, aot_config, fw_metadata + return compiled_fn(args) + + return wrapped_compiled_fn + """ + return compiled_fn + + +@dataclass +class AOTGraphCapture: # Produced by aot_stage1_graph_capture + # AOTAutograd typically operates by taking complicated graphs and + # desugaring them into simpler graphs that use PyTorch features. These + # wrappers establish invariants so that when we actually do tracing we can + # assume these invariants hold, leading to a simpler tracing + # implementation. However, this means that we have to keep track of how + # to enter/exit these wrappers when passing inputs into the compiled + # graph, among other things! + wrappers: list[CompilerWrapper] + + # The actual captured graph. In some circumstances (export) this graph + # has a specific calling convention that can be relied upon by external + # callers. In other situations, the calling convention is unspecified and + # only aot_stage2_compile knows how to deal with them. + graph: torch.fx.GraphModule + + # When compiling with autograd support, this is the joint_inputs, which is + # larger than the original flat_args as all tangents get inputs. The + # tuple organizes into primals and tangents. When not autograd it's just + # a plain list. + updated_flat_args: Union[list[Any], tuple[list[Any], list[Any]]] + + # Metadata about subclass inputs/outputs in the graph trace. + maybe_subclass_meta: Any + + +FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) + + +TOutputCode = TypeVar("TOutputCode", bound="OutputCode") + + +class AOTDispatchCompiler(Protocol): + """ + Represents a fw or bw_compiler passed to AOTAutograd. + """ + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> Any: ... + + +# TODO: bikeshed on this name +class SerializableAOTDispatchCompiler(AOTDispatchCompiler): + """ + Represents an AOTDispatchCompiler that returns an OutputCode, and is + therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. + A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of + the kwargs in _CompileFxKwargs. + """ + + def __init__( + self, + output_code_ty: type[TOutputCode], + compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], + ): + self.output_code_ty = output_code_ty + self.compiler_fn = compiler_fn + + def __call__( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[InputType], + ) -> OutputCode: + return self.compiler_fn(gm, example_inputs) diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 56367c0c4676..495193c89f61 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -1,10 +1,10 @@ # mypy: ignore-errors +import contextlib import itertools -from collections.abc import KeysView, Sequence -from contextlib import contextmanager, nullcontext -from functools import partial, wraps -from typing import Any, Callable, NewType, Optional, Protocol, TypeVar +from contextlib import nullcontext +from functools import wraps +from typing import Any, Callable, Optional from unittest.mock import patch import torch @@ -24,15 +24,10 @@ ) from torch._guards import detect_fake_mode from torch._inductor.cudagraph_utils import BoxedDeviceIndex -from torch._inductor.output_code import OutputCode -from torch._inductor.utils import BoxedBool, InputType +from torch._inductor.utils import BoxedBool from torch._subclasses import FakeTensor, FakeTensorMode -from torch.fx.experimental.proxy_tensor import ( - _pytree_subclasses_that_lose_info, - make_fx, -) +from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ShapeEnv -from torch.utils._python_dispatch import is_traceable_wrapper_subclass static_inputs_log = torch._logging.getArtifactLogger( @@ -48,6 +43,12 @@ from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 run_functionalized_fw_and_collect_metadata, ) +from ._aot_autograd.frontend_utils import ( + _detect_attribute_assignment, + _try_get_metadata_from_dynamo, + construct_fake_mode, + process_inputs, +) from ._aot_autograd.functional_utils import ( # noqa: F401 _check_if_mutation_can_be_in_graph, are_all_mutations_hidden_from_autograd, @@ -61,17 +62,26 @@ sync_functional_tensor, to_fun, ) +from ._aot_autograd.graph_capture_wrappers import ( # noqa: F401 + aot_dispatch_subclass, + create_functional_call, + create_functionalized_fn, + create_functionalized_rng_ops_wrapper, + create_joint, + fn_input_mutations_to_outputs, + fn_prepped_for_autograd, +) +from ._aot_autograd.graph_compile import ( # noqa: F401 + aot_stage1_graph_capture, + aot_stage2_compile, + aot_stage2_export, +) from ._aot_autograd.input_output_analysis import ( # noqa: F401 compute_overlapping_inputs, create_graph_signature, create_synthetic_base_metadata, remove_dupe_metadata, ) -from ._aot_autograd.jit_compile_runtime_wrappers import ( # noqa: F401 - aot_dispatch_autograd, - aot_dispatch_base, - aot_dispatch_export, -) from ._aot_autograd.logging_utils import ( # noqa: F401 callback_set, describe_input, @@ -92,7 +102,10 @@ ) from ._aot_autograd.schemas import ( # noqa: F401 AOTConfig, + AOTDispatchCompiler, + AOTState, BackwardSignature, + FakifiedFlatArgs, FQN, GraphInputName, GraphOutputName, @@ -101,6 +114,7 @@ MutationType, OutputAliasInfo, OutputType, + SerializableAOTDispatchCompiler, SubclassCreationMeta, SubclassMeta, TensorAlias, @@ -113,15 +127,6 @@ wrap_tensor_subclasses, wrap_tensor_subclasses_maybe_joint, ) -from ._aot_autograd.traced_function_transforms import ( # noqa: F401 - aot_dispatch_subclass, - create_functional_call, - create_functionalized_fn, - create_functionalized_rng_ops_wrapper, - create_joint, - fn_input_mutations_to_outputs, - fn_prepped_for_autograd, -) from ._aot_autograd.utils import ( # noqa: F401 _get_autocast_states, _get_symint_hints, @@ -439,151 +444,15 @@ aot_autograd_decompositions = {} -FakifiedFlatArgs = NewType("FakifiedFlatArgs", list[Any]) - - -TOutputCode = TypeVar("TOutputCode", bound=OutputCode) - - -class AOTDispatchCompiler(Protocol): - """ - Represents a fw or bw_compiler passed to AOTAutograd. - """ - - def __call__( - self, - gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - ) -> Any: ... - - -# TODO: bikeshed on this name -class SerializableAOTDispatchCompiler(AOTDispatchCompiler): - """ - Represents an AOTDispatchCompiler that returns an OutputCode, and is - therefore cacheable. SerializableAOTDispatchCompiler always return an OutputCode. - A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of - the kwargs in _CompileFxKwargs. - """ - - def __init__( - self, - output_code_ty: type[TOutputCode], - compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode], - ): - self.output_code_ty = output_code_ty - self.compiler_fn = compiler_fn - - def __call__( - self, - gm: torch.fx.GraphModule, - example_inputs: Sequence[InputType], - ) -> OutputCode: - return self.compiler_fn(gm, example_inputs) - - -def process_inputs( - flat_args: list[Any], - aot_config: AOTConfig, - fake_mode: FakeTensorMode, - shape_env: Optional[ShapeEnv], - ignore_shape_env: bool = False, -) -> FakifiedFlatArgs: - with fake_mode: - - def convert(idx, x): - if shape_env is not None and not ignore_shape_env: - from torch._dynamo.source import ConstantSource - - if isinstance(x, int): - # We always specialize on scalar values in export. - if aot_config.is_export: - return x - source = ConstantSource(f"sym_{idx}") - return shape_env.create_symintnode( - shape_env.create_symbol(x, source), hint=x, source=source - ) - if isinstance(x, torch.ScriptObject): - return torch._library.fake_class_registry.maybe_to_fake_obj( - fake_mode, x - ) - if not isinstance(x, torch.Tensor): - return x - if isinstance(x, FakeTensor): - assert x.fake_mode is fake_mode - return x - if is_traceable_wrapper_subclass(x): - attrs, _ = x.__tensor_flatten__() - if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs): - assert all( - getattr(x, attr).fake_mode is fake_mode for attr in attrs - ) - return x - - # see note [Tensor Fakification and Symbol Caching] - symbolic_context = None - source = None - trace = True - if tracing_context := torch._guards.TracingContext.try_get(): - if x in tracing_context.tensor_to_context: - symbolic_context = tracing_context.tensor_to_context[x] - source = symbolic_context.tensor_source - # We already fakeified this tensor in Dynamo, don't - # dump the trace for it again - trace = False - if ( - idx < aot_config.num_params_buffers - and config.static_weight_shapes - and not symbolic_context - ): - # TODO: Ensure that this codepath is never exercised from - # Dynamo - return fake_mode.from_tensor(x, static_shapes=True) - - result = fake_mode.from_tensor( - x, - static_shapes=ignore_shape_env, - symbolic_context=symbolic_context, - source=source, - trace=trace, - ) - return result - - return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) - -def construct_fake_mode( - flat_args: list[Any], aot_config: AOTConfig -) -> tuple[FakeTensorMode, Optional[ShapeEnv]]: - fake_mode = detect_fake_mode(flat_args) - if fake_mode is None: - shape_env = ShapeEnv() if aot_config.dynamic_shapes else None - fake_mode = FakeTensorMode(shape_env=shape_env) - else: - shape_env = fake_mode.shape_env - return (fake_mode, shape_env) - - -def create_aot_dispatcher_function( +def create_aot_state( + stack: contextlib.ExitStack, flat_fn, fake_flat_args: FakifiedFlatArgs, aot_config: AOTConfig, fake_mode: FakeTensorMode, shape_env: Optional[ShapeEnv], -) -> tuple[Callable, ViewAndMutationMeta]: - with dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True): - return _create_aot_dispatcher_function( - flat_fn, fake_flat_args, aot_config, fake_mode, shape_env - ) - - -def _create_aot_dispatcher_function( - flat_fn, - fake_flat_args: FakifiedFlatArgs, - aot_config: AOTConfig, - fake_mode: FakeTensorMode, - shape_env: Optional[ShapeEnv], -) -> tuple[Callable, ViewAndMutationMeta]: +) -> AOTState: """ Traces the forward and backward graphs of the attr:`flat_fn` to generate a joint graph. The joint graph is an Fx graph with Aten ops. Please refer to @@ -600,12 +469,14 @@ def _create_aot_dispatcher_function( inputs in flat_args are parameters and buffers, and the rest are inputs. We use this to assume that parameters/buffer's shapes don't change. - - Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export) - When aot_config.is_export is True, we return an FX graph + metadata - When aot_config.is_export is False, we return an ordinary runtime function """ + # Old name for now to avoid messing with stats. Also, note this is pushed + # on the stack, so it extends BEYOND this function + stack.enter_context( + dynamo_timed("create_aot_dispatcher_function", log_pt2_compile_event=True) + ) + # This is the main entry point. # TODO: Chillee argues that dynamo itself should pass in fake tensors to # the list of arguments when compiling; at the moment we do not do this @@ -636,210 +507,184 @@ def _create_aot_dispatcher_function( # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function # that we generate in torch.compile. - with ( - torch.autograd.set_multithreading_enabled(False), - preserve_rng_state(), - fake_mode, - python_dispatcher_mode, - PhiloxStateTracker(), - torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), - ): - from torch._library.fake_class_registry import ( - FakeScriptObject, - maybe_to_fake_obj, - ) + stack.enter_context(torch.autograd.set_multithreading_enabled(False)) + stack.enter_context(preserve_rng_state()) + stack.enter_context(fake_mode) + stack.enter_context(python_dispatcher_mode) + stack.enter_context(PhiloxStateTracker()) + stack.enter_context( + torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing() + ) - # Tracing may mutate the states the fake script object, - # so we need to duplicate the fake script objects so that subsequent tracing - # won't be affected. - def _dup_fake_script_obj(fake_flat_args): - return [ - maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) - if isinstance(arg, FakeScriptObject) - else arg - for arg in fake_flat_args - ] + from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj + + # Tracing may mutate the states the fake script object, + # so we need to duplicate the fake script objects so that subsequent tracing + # won't be affected. + def _dup_fake_script_obj(fake_flat_args): + return [ + maybe_to_fake_obj(detect_fake_mode(fake_flat_args), arg.real_obj) + if isinstance(arg, FakeScriptObject) + else arg + for arg in fake_flat_args + ] + + needs_autograd = any( + x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) + ) - needs_autograd = any( - x.requires_grad for x in fake_flat_args if isinstance(x, Tensor) - ) + with enable_python_dispatcher(): + # Patch set_rng_state as set_rng_state with fake tensors is + # nonsensical. This does not affect the collection of metadata. + with patch("torch.cuda.set_rng_state", lambda *args: None): + mod = root_module_when_exporting_non_strict(flat_fn) + if mod is not None: + ctx = _detect_attribute_assignment(mod) + else: + ctx = nullcontext() - with enable_python_dispatcher(): - # Patch set_rng_state as set_rng_state with fake tensors is - # nonsensical. This does not affect the collection of metadata. - with patch("torch.cuda.set_rng_state", lambda *args: None): - mod = root_module_when_exporting_non_strict(flat_fn) - if mod is not None: - ctx = _detect_attribute_assignment(mod) - else: - ctx = nullcontext() + if torch._functorch.config.fake_tensor_propagate_real_tensors: + # Running dynamo_timed causes fake tensor issues when + # propagate real tensor is switched on. + dynamo_timed_ctx = nullcontext() + else: + dynamo_timed_ctx = dynamo_timed( + "aot_collect_metadata", log_pt2_compile_event=True + ) - if torch._functorch.config.fake_tensor_propagate_real_tensors: - # Running dynamo_timed causes fake tensor issues when - # propagate real tensor is switched on. - dynamo_timed_ctx = nullcontext() - else: - dynamo_timed_ctx = dynamo_timed( - "aot_collect_metadata", log_pt2_compile_event=True - ) + with dynamo_timed_ctx, ctx: + fw_metadata = run_functionalized_fw_and_collect_metadata( + flat_fn, + static_input_indices=aot_config.static_input_indices, + keep_input_mutations=aot_config.keep_inference_input_mutations, + is_train=needs_autograd, + pre_dispatch=aot_config.pre_dispatch, + is_export=aot_config.is_export, + )(*_dup_fake_script_obj(fake_flat_args)) + + req_subclass_dispatch = requires_subclass_dispatch( + fake_flat_args, fw_metadata + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + output_and_mutation_safe = not any( + x.requires_grad + # view-type operations preserve requires_grad even in no_grad. + # Do not count aliases of inputs with requires_grad as reason to make a training graph, + # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, + # setting their grad_fn properly. + and not ( + x.output_type in (OutputType.alias_of_input, OutputType.is_input) + and fw_metadata.input_info[x.base_idx].requires_grad + ) + for x in fw_metadata.output_info + ) and not any( + x.requires_grad + and x.mutates_data + and not x.mutations_under_no_grad_or_inference_mode + and not x.mutations_hidden_from_autograd + for x in fw_metadata.input_info + ) - with dynamo_timed_ctx, ctx: + if needs_autograd and output_and_mutation_safe: + # We realized that none of the outputs require grad, + # and none of the inputs that require grad are mutated. + # so we actually have an inference graph. + needs_autograd = False + # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta + # changes depending on whether we pass in is_train / keep_input_mutations, + # so we're forced to recompute the metadata. + # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata + # so that this is unnecessary. + if req_subclass_dispatch: fw_metadata = run_functionalized_fw_and_collect_metadata( flat_fn, - static_input_indices=aot_config.static_input_indices, keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=needs_autograd, + is_train=False, pre_dispatch=aot_config.pre_dispatch, - is_export=aot_config.is_export, - )(*_dup_fake_script_obj(fake_flat_args)) - - req_subclass_dispatch = requires_subclass_dispatch( - fake_flat_args, fw_metadata - ) - CompileEventLogger.try_add_pt2_compile( - "backend_compile", requires_subclass_dispatch=req_subclass_dispatch - ) - - output_and_mutation_safe = not any( - x.requires_grad - # view-type operations preserve requires_grad even in no_grad. - # Do not count aliases of inputs with requires_grad as reason to make a training graph, - # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime, - # setting their grad_fn properly. - and not ( - x.output_type - in (OutputType.alias_of_input, OutputType.is_input) - and fw_metadata.input_info[x.base_idx].requires_grad + static_input_indices=aot_config.static_input_indices, + )(*fake_flat_args) + else: + fw_metadata = ViewAndMutationMeta( + input_info=fw_metadata.input_info, + output_info=fw_metadata.output_info, + num_intermediate_bases=fw_metadata.num_intermediate_bases, + keep_input_mutations=aot_config.keep_inference_input_mutations, + traced_tangents=fw_metadata.traced_tangents, + subclass_inp_meta=fw_metadata.subclass_inp_meta, + subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, + subclass_tangent_meta=fw_metadata.subclass_tangent_meta, + is_train=False, + tokens=fw_metadata.tokens, + static_input_indices=fw_metadata.static_input_indices, ) - for x in fw_metadata.output_info - ) and not any( - x.requires_grad - and x.mutates_data - and not x.mutations_under_no_grad_or_inference_mode - and not x.mutations_hidden_from_autograd - for x in fw_metadata.input_info - ) - if needs_autograd and output_and_mutation_safe: - # We realized that none of the outputs require grad, - # and none of the inputs that require grad are mutated. - # so we actually have an inference graph. - needs_autograd = False - # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta - # changes depending on whether we pass in is_train / keep_input_mutations, - # so we're forced to recompute the metadata. - # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata - # so that this is unnecessary. - if req_subclass_dispatch: - fw_metadata = run_functionalized_fw_and_collect_metadata( - flat_fn, - keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=False, - pre_dispatch=aot_config.pre_dispatch, - static_input_indices=aot_config.static_input_indices, - )(*fake_flat_args) - else: - fw_metadata = ViewAndMutationMeta( - input_info=fw_metadata.input_info, - output_info=fw_metadata.output_info, - num_intermediate_bases=fw_metadata.num_intermediate_bases, - keep_input_mutations=aot_config.keep_inference_input_mutations, - traced_tangents=fw_metadata.traced_tangents, - subclass_inp_meta=fw_metadata.subclass_inp_meta, - subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, - subclass_tangent_meta=fw_metadata.subclass_tangent_meta, - is_train=False, - tokens=fw_metadata.tokens, - static_input_indices=fw_metadata.static_input_indices, - ) - - if fw_metadata.num_intermediate_bases > 0: - assert not req_subclass_dispatch, f"""\ + if fw_metadata.num_intermediate_bases > 0: + assert not req_subclass_dispatch, f"""\ torch.compile is currently being used with tensor subclass inputs: {",".join([str(type(x)) for x in fake_flat_args])}. We are attempting to a compile a graph with two graph outputs that alias one another, which is currently unsupported in the subclass use case. If you run into this, please file a github issue""" - if aot_config.is_export: - # aot_export: ban input metadata mutations for now to keep shared code paths simpler. - # Keeping .resize_() in the graph will require some work - # Allowing it but keeping the graph functional will require some calling convention changes. - if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: - raise RuntimeError( - f"""\ + if aot_config.is_export: + # aot_export: ban input metadata mutations for now to keep shared code paths simpler. + # Keeping .resize_() in the graph will require some work + # Allowing it but keeping the graph functional will require some calling convention changes. + if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0: + raise RuntimeError( + f"""\ Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" - ) - # In export, banning data mutations on inputs that require grad for now. - # This should be rare, and is tricky to get right. When we trace the backward, - # we currently trace with autograd.grad instead of .backward(), which makes it difficult - # to ensure that we run autograd all the way through the input **before** it saw the mutation. - if ( - len( - [ - x - for x in fw_metadata.input_info - if x.requires_grad and x.mutates_data - ] - ) - != 0 - ): - raise RuntimeError( - f"""\ + ) + # In export, banning data mutations on inputs that require grad for now. + # This should be rare, and is tricky to get right. When we trace the backward, + # we currently trace with autograd.grad instead of .backward(), which makes it difficult + # to ensure that we run autograd all the way through the input **before** it saw the mutation. + if ( + len( + [ + x + for x in fw_metadata.input_info + if x.requires_grad and x.mutates_data + ] + ) + != 0 + ): + raise RuntimeError( + f"""\ Found a graph input that requires gradients, and received a mutation. This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue. fw_metadata={str(fw_metadata)}""" - ) - if req_subclass_dispatch: - raise RuntimeError( - """\ + ) + if req_subclass_dispatch: + raise RuntimeError( + """\ aot_export is not currently supported with traceable tensor subclass. If you need this feature, please comment on """ - ) + ) - # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, - # and turning it on will require a non-trivial calling convention change for any export runtime. - if config.functionalize_rng_ops: - raise RuntimeError( - """\ + # Need to decide on a strategy for functionalized RNG: toggling via global config seems bad, + # and turning it on will require a non-trivial calling convention change for any export runtime. + if config.functionalize_rng_ops: + raise RuntimeError( + """\ Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue, or otherwise set torch._functorch.config.functionalize_rng_ops = False.""" - ) - - def choose_dispatcher(needs_autograd, aot_config): - """ - Pick a dispatcher based on the config rules. - """ - if aot_config.is_export: - # export uses just the "graph bits", whereas the other - # two dispatchers include some extra work around handling a runtime epilogue - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="export" - ) - return partial(aot_dispatch_export, needs_autograd=needs_autograd) - elif needs_autograd and not aot_config.pre_dispatch: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - return aot_dispatch_autograd - else: - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - return aot_dispatch_base - - compiler_fn = choose_dispatcher(needs_autograd, aot_config) + ) - compiled_fn, fw_metadata = compiler_fn( - flat_fn, - _dup_fake_script_obj(fake_flat_args), - aot_config, - fw_metadata=fw_metadata, - ) - return compiled_fn, fw_metadata + return AOTState( + needs_autograd=needs_autograd, + flat_args=_dup_fake_script_obj(fake_flat_args), + fw_metadata=fw_metadata, + # Packaging this just for later use + aot_config=aot_config, + stack=stack, + ) def aot_function( @@ -942,13 +787,12 @@ def returned_function(*args, **kwargs): fake_flat_args: FakifiedFlatArgs = process_inputs( flat_args, aot_config, fake_mode, shape_env ) - compiled_fn, _ = create_aot_dispatcher_function( - flat_fn, - fake_flat_args, - aot_config, - fake_mode, - shape_env, - ) + with contextlib.ExitStack() as stack: + aot_state = create_aot_state( + stack, flat_fn, fake_flat_args, aot_config, fake_mode, shape_env + ) + aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) + compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) cached_res = (compiled_fn, out_spec) cached_fn, out_spec = cached_res @@ -1010,110 +854,20 @@ def forward(self, *args, **kwargs): return AOTModule() -def _try_get_metadata_from_dynamo( - mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int -) -> tuple[Optional[list[torch._guards.Source]], list[int]]: - """ - Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. - We first verify that `mod` does come from Dynamo, then we handle cases where - metadata might be missing. - - Returns: - aot_autograd_arg_pos_to_source: used to dedup params and their guards - static_input_indices: used to identify static inputs for cudagraphs - """ - # Note [Assumption on Dynamo Metadata] - # This function assumes a graph module from dynamo provides `dynamo_compiled_id`, - # _param_name_to_source, and every placeholder node has `_dynamo_source` attributes. - # When gm is modified (e.g., DDPOptimizer via split_module), metadata needs to - # be propagated in order to be recognized as a dynamo graph - - if not (isinstance(mod, torch.fx.GraphModule) and "dynamo_compile_id" in mod.meta): - # graph was not captured by dynamo - return None, [] - - if not hasattr(mod, "_param_name_to_source"): - # is from export - return None, [] - - # We now know this came from dynamo, and (1) we care about guards, - # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards - # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below. - # Additionally, we mark static indices for cudagraphs. - param_name_to_source = mod._param_name_to_source - seen_sources = set() - - aot_autograd_arg_pos_to_source = [] - static_input_indices = [] - # Collect the new inputs lifted by aotdispatch - for i, name in enumerate(param_keys): - assert name in param_name_to_source, f"{name} not found." - source = param_name_to_source[name] - assert source not in seen_sources, source - seen_sources.add(source) - aot_autograd_arg_pos_to_source.append(source) - - static_input_indices.append(i) - - # Collect the dynamo graph inputs - # TODO(mlazos): Revisit if this is still needed. With Dynamo install ID - # matched tensors back into the Fx graph, this might not be necessary. - for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")): - assert hasattr(node, "_dynamo_source") - source = node._dynamo_source - # `source`` specifies the source from user code. ddp optimizer may have - # intermediate values becoming submodule placeholders which does not - # have a source - assert source is None or source not in seen_sources, source - seen_sources.add(source) - aot_autograd_arg_pos_to_source.append(source) - source_name = source.name() if source else str(source) - - # input[i] in dynamo is now: - # input[i + len(extra_params)] in AOT, - # where extra_params are the params/buffers that dynamo baked into the - # OutputGraph - actual_pos = pos + len(param_keys) - - if "tensor_dict" in node.meta and node.meta["tensor_dict"].get( - "_dynamo_static_input_type", None - ): - static_inputs_log.debug( - "Adding static input pos %s for source %s", actual_pos, source_name - ) - static_input_indices.append(actual_pos) - else: - static_inputs_log.debug( - "Non-static input pos %s for source %s", actual_pos, source_name - ) - - assert full_args_num == len(aot_autograd_arg_pos_to_source) - return aot_autograd_arg_pos_to_source, static_input_indices - - -def aot_module_simplified( +def prepare_aot_module_simplified( mod: nn.Module, args, fw_compiler: AOTDispatchCompiler, - bw_compiler: Optional[AOTDispatchCompiler] = None, - partition_fn: Callable = default_partition, - decompositions: Optional[dict] = None, - keep_inference_input_mutations=False, - inference_compiler: Optional[AOTDispatchCompiler] = None, - cudagraphs: Optional[BoxedBool] = None, - boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, - ignore_shape_env: bool = False, -) -> nn.Module: - """ - This is the simplified or low overhead version of aot_module. For frontends - like TorchDynamo, the input functions/modules to AOT are static and have - unpacked inputs/outputs. This gives us an opportunity to remove the - (1) pytree overhead to parse inputs/outputs, - (2) AOT Autograd cache, - (3) Reading of params/buffers in every forward call - - :func:`aot_module_simplified` removes these overheads. - """ + bw_compiler: AOTDispatchCompiler, + partition_fn: Callable, + decompositions: dict, + keep_inference_input_mutations, + inference_compiler: AOTDispatchCompiler, + boxed_forward_device_index: BoxedDeviceIndex, + ignore_shape_env: bool, +): + # TODO: There's something a bit suspicious here; typically simplified + # module shouldn't actually have any parameters... params = { **dict(mod.named_parameters(remove_duplicate=False)), **dict(mod.named_buffers(remove_duplicate=False)), @@ -1122,14 +876,6 @@ def aot_module_simplified( params_flat = list(params_flat) params_len = len(params_flat) - if cudagraphs is None: - cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) - - if bw_compiler is None: - bw_compiler = fw_compiler - if inference_compiler is None: - inference_compiler = fw_compiler - full_args = [] # First, the params full_args.extend(params_flat) @@ -1177,68 +923,125 @@ def aot_module_simplified( fake_flat_args = process_inputs( full_args, aot_config, fake_mode, shape_env, ignore_shape_env ) + functional_call = create_functional_call(mod, params_spec, params_len) + + return ( + functional_call, + params_flat, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) + - def dispatch_and_compile(): - functional_call = create_functional_call(mod, params_spec, params_len) - with compiled_autograd._disable(): - compiled_fn, _ = create_aot_dispatcher_function( +def aot_module_simplified( + mod: nn.Module, + args, + fw_compiler: AOTDispatchCompiler, + bw_compiler: Optional[AOTDispatchCompiler] = None, + partition_fn: Callable = default_partition, + decompositions: Optional[dict] = None, + keep_inference_input_mutations=False, + inference_compiler: Optional[AOTDispatchCompiler] = None, + # TODO: This doesn't seem to be used in any nontrivial way, check if it's + # actually needed + cudagraphs: Optional[BoxedBool] = None, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + ignore_shape_env: bool = False, +) -> nn.Module: + """ + This is the simplified or low overhead version of aot_module. For frontends + like TorchDynamo, the input functions/modules to AOT are static and have + unpacked inputs/outputs. This gives us an opportunity to remove the + (1) pytree overhead to parse inputs/outputs, + (2) AOT Autograd cache, + (3) Reading of params/buffers in every forward call + + :func:`aot_module_simplified` removes these overheads. + """ + + if cudagraphs is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + if bw_compiler is None: + bw_compiler = fw_compiler + if inference_compiler is None: + inference_compiler = fw_compiler + + with contextlib.ExitStack() as stack: + ( + functional_call, + params_flat, + fake_flat_args, + aot_config, + fake_mode, + shape_env, + ) = prepare_aot_module_simplified( + mod, + args, + fw_compiler, + bw_compiler, + partition_fn, + decompositions, + keep_inference_input_mutations, + inference_compiler, + boxed_forward_device_index, + ignore_shape_env, + ) + + compiled_fn = None + + if isinstance(fw_compiler, SerializableAOTDispatchCompiler): + local = should_use_local_autograd_cache() + remote = should_use_remote_autograd_cache() + if local or remote: + set_feature_use("aot_autograd_remote_cache", remote) + compiled_fn = AOTAutogradCache.try_load( + mod, + fake_flat_args, + aot_config, + cudagraphs, + boxed_forward_device_index, + local, + remote, + ) + + if compiled_fn is None: + stack.enter_context(compiled_autograd._disable()) + aot_state = create_aot_state( + stack, functional_call, fake_flat_args, aot_config, fake_mode, shape_env, ) - return compiled_fn - - # We only care if the forward will return an OutputCode. - if isinstance(fw_compiler, SerializableAOTDispatchCompiler): - local = should_use_local_autograd_cache() - remote = should_use_remote_autograd_cache() - if local or remote: - set_feature_use("aot_autograd_remote_cache", remote) - compiled_fn = AOTAutogradCache.load( - dispatch_and_compile, - mod, - fake_flat_args, - aot_config, - cudagraphs, - boxed_forward_device_index, - local, - remote, - ) - else: - compiled_fn = dispatch_and_compile() - else: - compiled_fn = dispatch_and_compile() + aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) + compiled_fn, _ = aot_stage2_compile(aot_state, aot_graph_capture) if isinstance(mod, torch._dynamo.utils.GmWrapper): # This function is called by the flatten_graph_inputs wrapper, which boxes # the inputs so that they can be freed before the end of this scope. # For overhead reasons, this is not the default wrapper, see comment: # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481 - def boxed_forward(runtime_args: list[Any]): + def forward(runtime_args: list[Any]): flat_args = [] flat_args.extend(params_flat) flat_args.extend(runtime_args) runtime_args.clear() return compiled_fn(flat_args) - # Just for convenience - boxed_forward.zero_grad = mod.zero_grad - boxed_forward.named_parameters = mod.named_parameters - boxed_forward.named_buffers = mod.named_buffers - return boxed_forward - - # TODO: There is something deeply wrong here; compiled_fn running with - # the boxed calling convention, but aot_module_simplified somehow - # historically returned a function that was not the boxed calling - # convention. This should get fixed... - # NB: GraphModule/nn.Module rely on the non-boxed calling convention here - def forward(*runtime_args: tuple[Any]): - full_args = [] - full_args.extend(params_flat) - full_args.extend(runtime_args) - return compiled_fn(full_args) + else: + # TODO: There is something deeply wrong here; compiled_fn running with + # the boxed calling convention, but aot_module_simplified somehow + # historically returned a function that was not the boxed calling + # convention. This should get fixed... + # NB: GraphModule/nn.Module rely on the non-boxed calling convention here + def forward(*runtime_args: tuple[Any]): + full_args = [] + full_args.extend(params_flat) + full_args.extend(runtime_args) + return compiled_fn(full_args) # Just for convenience forward.zero_grad = mod.zero_grad @@ -1394,6 +1197,8 @@ def fn_to_trace(*args): dynamic_shapes=dynamic_shapes, kwargs=kwargs, ) + + # TODO: subsume this path with the aot_stage2_graph_capture path if trace_joint: @wraps(functional_call) @@ -1624,114 +1429,19 @@ def _aot_export_function( shape_env = fake_mode.shape_env fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) - fx_g, meta = create_aot_dispatcher_function( - flat_fn, - fake_flat_args, - aot_config, - fake_mode, - shape_env, - ) - return fx_g, meta, in_spec, out_spec.spec - - -@contextmanager -def _detect_attribute_assignment(mod: torch.nn.Module): - # Do not allow assignment of tensor attributes during export unless - # the attribute is registered as a buffer. - - NN_MODULE_STD_ATTRS = [ - "_backward_hooks", - "_backward_pre_hooks", - "_buffers", - "_forward_hooks", - "_forward_hooks_always_called", - "_forward_hooks_with_kwargs", - "_forward_pre_hooks", - "_forward_pre_hooks_with_kwargs", - "_is_full_backward_hook", - "_load_state_dict_post_hooks", - "_load_state_dict_pre_hooks", - "_modules", - "_non_persistent_buffers_set", - "_parameters", - "_state_dict_hooks", - "_state_dict_pre_hooks", - "training", - ] - NN_MODULE_LAZY_STD_ATTRS = [ - "_initialize_hook", - "_load_hook", - ] - STD_ATTRS = { - *NN_MODULE_STD_ATTRS, - *NN_MODULE_LAZY_STD_ATTRS, - } - - def _get_attributes(mod): - # return any attributes of a module that are not standard attributes - return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS} - - # save state of attributes before enter - snapshot = pytree.tree_map( - lambda x: x, - _get_attributes(mod), - is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info, - ) - try: - yield - finally: - # after exit, compare state of attributes with snapshot - # to detect which tensor attributes were assigned - assigned_tensor_attributes = [] - - def _collect_assigned_tensor_attributes(kp, v, _v): - if _v is not v: - attr, *rest = kp - if isinstance(v, torch.Tensor): - assigned_tensor_attributes.append( - f"self.{attr.key}{pytree.keystr(rest)}" - ) - # TODO(avik): Assigning all other types are allowed right now. - # Maybe in the future we want to limit this to primitive types? - return v - - new_attrs = _get_attributes(mod) - if len(new_attrs) != len(snapshot): - added_attrs = new_attrs.keys() - snapshot.keys() - deleted_attrs = snapshot.keys() - new_attrs.keys() - - if len(added_attrs) > 0: - raise ValueError( - f"During torch.export, following attrs were created in the model.forward: {added_attrs} " - f"Such attributes must be registered as buffers using the `register_buffer` " - f"API and must be initialized at model.__init__ " - f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) - - if len(deleted_attrs) > 0: - raise ValueError( - f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} " - f"Such attributes must be registered as buffers using the `register_buffer` " - f"API and must be initialized at model.__init__ " - f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) - - pytree.tree_map_with_path( - _collect_assigned_tensor_attributes, snapshot, new_attrs + with contextlib.ExitStack() as stack: + aot_state = create_aot_state( + stack, + flat_fn, + fake_flat_args, + aot_config, + fake_mode, + shape_env, ) - # restore state of all attributes (including, e.g., of primitive types) - mod.__dict__.update(snapshot) + aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) + fx_g, meta = aot_stage2_export(aot_state, aot_graph_capture) - if assigned_tensor_attributes: - if len(assigned_tensor_attributes) > 1: - noun, verb = "attributes", "were" - else: - noun, verb = "attribute", "was" - raise ValueError( - f"The tensor {noun} {', '.join(assigned_tensor_attributes)} {verb} assigned during export. " - "Such attributes must be registered as buffers using the `register_buffer` API " - "(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)." - ) + return fx_g, meta, in_spec, out_spec.spec compiled_function = aot_function diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index e8778f31889d..2833a2b1631a 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -61,6 +61,10 @@ # need to add env vars or make it configurable bundled_autograd_cache: bool = False +# Whether or not to normalize placeholder names in graphs +# from dynaom in AOTAutogradCache +autograd_cache_normalize_inputs = not is_fbcode() + def remote_autograd_cache_default() -> Optional[bool]: if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1": diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 7b36092e09eb..c666a924b468 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -684,47 +684,11 @@ def quantize_activation_bw(graph: torch.fx.Graph) -> None: counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1 -def enable_activation_quantization( - saved_values: list[fx.Node], +def perform_fp8_activation_quantization( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, - static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, + bwd_module_inputs: dict[str, fx.Node], ) -> None: - if ( - inductor_config.post_grad_fusion_options.get( - "activation_quantization_aten_pass", None - ) - is None - ): - return - - static_input_names = ( - [node.name for node in static_lifetime_input_nodes] - if static_lifetime_input_nodes - else [] - ) - saved_values_names = {node.name: node for node in saved_values} - if torch._inductor.config.post_grad_fusion_options[ - "activation_quantization_aten_pass" - ].get("exclude_primals", False): - saved_values_names = { - node.name: node for node in saved_values if "primals" not in node.name - } - fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] - bwd_module_inputs = { - node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") - } - for node in fwd_module_outputs: - if node.name in saved_values_names and should_quantize(node): - if node.name in static_input_names: - log.debug("Skipping quantization of static input %s: ", node.name) - continue - node.meta["saved_for_quantization"] = True - node.meta["dequant_type"] = node.meta["val"].dtype - # some of the fwd outputs and bwd inputs are not share the same object - bwd_module_inputs[node.name].meta["saved_for_quantization"] = True - bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype - trace_structured( "artifact", metadata_fn=lambda: { @@ -808,6 +772,53 @@ def enable_activation_quantization( ) +def enable_activation_quantization( + saved_values: list[fx.Node], + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, + static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None, +) -> None: + if ( + inductor_config.post_grad_fusion_options.get( + "activation_quantization_aten_pass", None + ) + is None + ): + return + + static_input_names = ( + [node.name for node in static_lifetime_input_nodes] + if static_lifetime_input_nodes + else [] + ) + saved_values_names = {node.name: node for node in saved_values} + if torch._inductor.config.post_grad_fusion_options[ + "activation_quantization_aten_pass" + ].get("exclude_primals", False): + saved_values_names = { + node.name: node for node in saved_values if "primals" not in node.name + } + fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0] + bwd_module_inputs = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + should_perform_fp8_quant = False + for node in fwd_module_outputs: + if node.name in saved_values_names and should_quantize(node): + if node.name in static_input_names: + log.debug("Skipping quantization of static input %s: ", node.name) + continue + node.meta["saved_for_quantization"] = True + node.meta["dequant_type"] = node.meta["val"].dtype + # some of the fwd outputs and bwd inputs are not share the same object + bwd_module_inputs[node.name].meta["saved_for_quantization"] = True + bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype + should_perform_fp8_quant = True + + if should_perform_fp8_quant: + perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs) + + def _extract_fwd_bwd_modules( joint_module: fx.GraphModule, saved_values: list[fx.Node], diff --git a/torch/_guards.py b/torch/_guards.py index 9d64513d01be..bce574df3feb 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -155,17 +155,6 @@ def is_fsdp_module(self) -> bool: return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) def is_specialized_nn_module(self) -> bool: - import torch._dynamo.config as config - - if config._unsafe_skip_fsdp_module_guards: - return ( - self - in ( - GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, - GuardSource.LOCAL_SPECIALIZED_NN_MODULE, - ) - or self.is_fsdp_module() - ) return self in ( GuardSource.GLOBAL_SPECIALIZED_NN_MODULE, GuardSource.LOCAL_SPECIALIZED_NN_MODULE, diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 2000571f6057..e00036a8c14e 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -111,16 +111,22 @@ def _maybe_compile_and_run_fn(fn, *args): def reenter_make_fx(fn): + from torch._guards import detect_fake_mode from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER + from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts @functools.wraps(fn) def wrapped(*args): assert _CURRENT_MAKE_FX_TRACER is not None, ( "Cannot reenter make_fx when we're not under a make_fx tracing session" ) - return _CURRENT_MAKE_FX_TRACER.trace_subgraph( + gm = _CURRENT_MAKE_FX_TRACER.trace_subgraph( _maybe_run_with_interpreter(fn), *args ) + if (fake_mode := detect_fake_mode()) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts(gm, fake_mode.shape_env, "reenter_make_fx") + gm.recompile() + return gm return wrapped @@ -236,6 +242,7 @@ def diff_device( def _set_compilation_env(): _old_is_tracing = torch.fx._symbolic_trace._is_fx_tracing_flag _old_allow_empty_graphs = torch._dynamo.config.allow_empty_graphs + _old_capture_scalar_outputs = torch._dynamo.config.capture_scalar_outputs # The issue is tracked in https://github.com/pytorch/pytorch/issues/144360: when dynamo finds # the top-level frame produces no graph, the default behavior is to fallback to eager. # Then when it encounters an inner function, it will try to trace that function again, which is unnecessary. @@ -249,10 +256,12 @@ def _set_compilation_env(): # once we are confident fx tracing works with dynamo. torch.fx._symbolic_trace._is_fx_tracing_flag = False torch._dynamo.config.allow_empty_graphs = True + torch._dynamo.config.capture_scalar_outputs = True yield finally: torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing torch._dynamo.config.allow_empty_graphs = _old_allow_empty_graphs + torch._dynamo.config.capture_scalar_outputs = _old_capture_scalar_outputs # The invariant here is that we always trace the branch with fake tensor @@ -856,7 +865,11 @@ def _get_shape_env( # the runtime assertions for unbacked symbols. new_fake_mode = torch._subclasses.FakeTensorMode( shape_env=_get_shape_env(fake_args), - allow_non_fake_inputs=False, + # In executorch, there's an scalar_to_tensor pass that turns scalar inputs into a tensor constant + # e.g. add(a, 1) 1 is turned into a tensor, which becomes a constant tensor attribute in the graph. + # We allow non fake inputs for this purpose. This is fine for mutation detection purpose: + # inputs are all fake and all mutations/aliasing are still detected. + allow_non_fake_inputs=True, ) # We need to temporarily turn inference_mode off because # under inference mode, tensor version counter is not tracked. diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index d94ccf16d216..16f460625616 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -17,6 +17,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, + disable_proxy_modes_tracing, ProxyTorchDispatchMode, track_tensor_tree, ) @@ -285,21 +286,44 @@ def _trace_while_loop( # iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the # unbacked symint output of subgraph as of today because this requires a smart range analysis. fake_mode: FakeTensorMode = _find_or_create_fake_mode() - unspecialized_carried_inputs = pytree.tree_map_only( - (int, torch.SymInt), - # For temporarily created unbacked symints, we don't need to bind them to any proxy - lambda _: _create_unbacked_symint( - fake_mode, ignore_fresh_unbacked_symbols=True - ), - carried_inputs, - ) - cond_graph = reenter_make_fx(cond_fn)( - *unspecialized_carried_inputs, *additional_inputs - ) - body_graph = reenter_make_fx(body_fn)( - *unspecialized_carried_inputs, *additional_inputs - ) + def _unspecialize_carried_inputs(x): + if isinstance(x, (int, torch.SymInt)): + return _create_unbacked_symint( + fake_mode, ignore_fresh_unbacked_symbols=True + ) + # Note: [unspecialize constant tensor carry] + # We need to disable constant specialization for tensor inputs that become loop carries. + # Here's the problem: when a user creates a constant tensor e.g. torch.tensor(0), PyTorch calls aten.lift_fresh_copy + # to create a safe copy (avoiding aliasing issues), which creates a FakeTensor with constant=True. + # But when this FakeTensor becomes a loop carry, we have a problem: + # - Operations like .item() will read the constant value and bake it into the traced code + # - This is incorrect because carry variables change between loop iterations + # - The traced code would use the wrong constant value for all iterations + # Solution: We clone the constant tensors and mark the cloned tensor as non-constant so they won't + # be specialized to fixed values during tracing body_fn or cond_fn. + elif isinstance(x, torch.Tensor): + x = x.clone() + if hasattr(x, "constant") and x.constant is not None: + x.constant = None + return x + + with disable_proxy_modes_tracing(): + unspecialized_carried_inputs = pytree.tree_map_only( + (int, torch.SymInt, torch.Tensor), + # For temporarily created unbacked symints, we don't need to bind them to any proxy + lambda x: _unspecialize_carried_inputs(x), + carried_inputs, + ) + + def produce_graph(fn): + cloned_carried_inputs = pytree.tree_map_only( + torch.Tensor, lambda x: x.clone(), unspecialized_carried_inputs + ) + return reenter_make_fx(fn)(*cloned_carried_inputs, *additional_inputs) + + cond_graph = produce_graph(cond_fn) + body_graph = produce_graph(body_fn) next_name = None i = 0 diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 94762a68b343..f80b71cbe420 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -6,7 +6,6 @@ import os from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union -import torch._inductor.config import torch.fx from .standalone_compile import CompiledArtifact # noqa: TC001 @@ -15,6 +14,7 @@ if TYPE_CHECKING: from torch._inductor.utils import InputType from torch.export import ExportedProgram + from torch.export.pt2_archive._package import AOTICompiledModel from torch.export.pt2_archive._package_weights import Weights from torch.types import FileLike @@ -223,7 +223,7 @@ def _aoti_compile_and_package_inner( not_strict_accuracy = check_accuracy == "accuracy" if not same_two_models( gm, - compiled_model, + compiled_model, # type: ignore[arg-type] args, only_fwd=True, require_fp64=not_strict_accuracy, @@ -238,7 +238,7 @@ def _aoti_compile_and_package_inner( def aoti_load_package( path: FileLike, run_single_threaded: bool = False, device_index: int = -1 -) -> Any: # type: ignore[type-arg] +) -> AOTICompiledModel: """ Loads the model from the PT2 package. diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b7bab02da5e4..689a006eb56b 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -365,6 +365,17 @@ def can_fuse( WhyNoFuse(node1, node2)("Fusion will increase peak memory") return False + if ( + config.realize_acc_reads_size_threshold is not None + and scheduler.fusion_accumulate_large_reads( + node1, + node2, + config.realize_acc_reads_size_threshold, + ) + ): + WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads") + return False + return True @staticmethod diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6ade10e63163..6ba9147c4e9b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -52,6 +52,7 @@ from torch._dynamo.utils import CompileEventLogger, counters, dynamo_timed from torch._inductor import config, exc, metrics from torch._inductor.codegen.common import ( + custom_backend_codegen_configs, custom_backend_passes, init_backend_registration, ) @@ -854,6 +855,13 @@ def __init__( map(self._get_custom_pass_detail, custom_backend_passes.values()) ) + # Save custom inductor codegen configs + self.custom_backend_codegen_configs = { + device: custom_config.save_config_portable(ignore_private_configs=False) + for device, custom_config in custom_backend_codegen_configs.items() + if custom_config is not None + } + # This is mainly added to handle these two inductor configs, which are (unfortunately) # sometimes cache safe: # - _pre_fusion_custom_pass @@ -1635,9 +1643,6 @@ def compile( """ generated_files: list[Union[str, Weights]] = additional_files # type: ignore[assignment] - if sys.platform == "win32": - raise RuntimeError("AotCodeCompiler not yet supported for inductor") - _set_gpu_runtime_env() # cpp_extension consults the env picked_vec_isa = pick_vec_isa() @@ -1804,9 +1809,18 @@ def _compile_consts(consts: bytes, platform: str) -> str: elif platform == "darwin": section_attr = "__DATA,__data" symbol_prefix = "_" + elif platform == "win32": + symbol_prefix = "" + # ASM build is not supported on Windows, force use CPP build. + use_asm_build = False else: raise RuntimeError(f"Unsupported platform: {platform}") + # Intel compiler failed to compile this manually constructed assembly file. + # Switch XPU to use consts cpp build. + if device_type == "xpu": + use_asm_build = False + is_large_consts = len(consts) > 1024 def format_consts_to_asm( @@ -1837,6 +1851,7 @@ def format_consts_to_asm( def format_consts_to_cpp( consts: bytes, align_bytes: int, symbol_prefix: str ) -> tuple[str, str]: + consts_size = len(consts) asan_attr = """#if defined(__clang__) || defined (__GNUC__)\t\n\ #define ATTRIBUTE_NO_SANITIZE_ADDRESS __attribute__((no_sanitize("address")))\t\n\ #else\t\n\ @@ -1846,7 +1861,7 @@ def format_consts_to_cpp( ATTRIBUTE_NO_SANITIZE_ADDRESS\t\n""" const_cpp = asan_attr const_cpp += f"alignas({align_bytes}) extern " - const_cpp += f"const unsigned char {symbol_prefix}_binary_constants_bin_start[] = {{\t\n" + const_cpp += f"unsigned char {symbol_prefix}_binary_constants_bin_start[{consts_size}] = {{\t\n" count_bytes = 0 for c in consts: const_cpp += f"{c}, " @@ -1854,7 +1869,7 @@ def format_consts_to_cpp( if count_bytes % 16 == 0: const_cpp += "\t\n" const_cpp += "};\t\n" - const_cpp += f"alignas({align_bytes}) extern const unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" + const_cpp += f"alignas({align_bytes}) extern unsigned char * {symbol_prefix}_binary_constants_bin_end;\t\n" return const_cpp, "cpp" if use_asm_build: @@ -1873,9 +1888,7 @@ def format_consts_to_cpp( ) consts_s = Path(consts_s) object_build_options = CppTorchDeviceOptions( - # Intel compiler failed to compile this manually constructed assembly file. - # it is ok to use gcc to compile the .S to a .o and linked with Intel compiler . - device_type=device_type if device_type != "xpu" else "cpu", + device_type=device_type, aot_mode=graph.aot_mode, compile_only=True, use_relative_path=use_relative_path, @@ -2163,40 +2176,44 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o = [] asm_files = [] - ld, objcopy = get_ld_and_objcopy(use_relative_path) - for kernel_name, value in CudaKernelParamCache.cache.items(): - if asm_file := value["asm"]: - asm_files.append(asm_file) - - cubin_file = value[get_cpp_wrapper_cubin_path_name()] - if config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda": - current_arch = _nvcc_arch_as_compile_option() - cmd = ( - f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " - # Triton only allows generating PTX version as same as the current arch - f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " - # Include SASS for the current specific arch - f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " - ) - try: - subprocess.run( - cmd.split(), - capture_output=True, - text=True, - check=True, - ) - except subprocess.CalledProcessError as e: - print( - f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", - file=sys.stderr, + if not _IS_WINDOWS: + ld, objcopy = get_ld_and_objcopy(use_relative_path) + for kernel_name, value in CudaKernelParamCache.cache.items(): + if asm_file := value["asm"]: + asm_files.append(asm_file) + + cubin_file = value[get_cpp_wrapper_cubin_path_name()] + if ( + config.aot_inductor.emit_multi_arch_kernel + and device_type == "cuda" + ): + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " ) - raise + try: + subprocess.run( + cmd.split(), + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise - if config.aot_inductor.embed_kernel_binary: - # Embed cubin files into model.so using objcopy - cubins_o.append( - convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) - ) + if config.aot_inductor.embed_kernel_binary: + # Embed cubin files into model.so using objcopy + cubins_o.append( + convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + ) output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) so_build_options = CppTorchDeviceOptions( @@ -3757,6 +3774,7 @@ def get_kernel_binary_remote_cache( return None @classmethod + @lru_cache(None) def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]: """ Writes source code into a file with dst_file_ext as the file extension. diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 828050d6da14..92ee9e28be74 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,6 +34,7 @@ import torch.fx from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +from torch.utils._config_module import ConfigModule from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter @@ -367,6 +368,7 @@ def cpp_global_scratch( device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} custom_backend_passes: dict[str, Optional[CustomGraphModulePass]] = {} +custom_backend_codegen_configs: dict[str, Optional[ConfigModule]] = {} # The code generated by Inductor consists of two main parts: kernel code and wrapper code. @@ -396,11 +398,20 @@ def register_backend_for_device( device_wrapper_codegen: WrapperConstructor, device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, device_custom_pass: Optional[CustomGraphModulePass] = None, + device_custom_config: Optional[ConfigModule] = None, ) -> None: device_codegens[device] = DeviceCodegen( device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen ) custom_backend_passes[device] = device_custom_pass + if device_custom_config: + assert ( + isinstance(device_custom_config, ConfigModule) + and device_custom_config is not config + ), ( + f"{device_custom_config=} cannot be the same as the default inductor config {config=}" + ) + custom_backend_codegen_configs[device] = device_custom_config class BackendFeature(Enum): @@ -463,6 +474,14 @@ def get_custom_backend_pass_for_device(device: str) -> Optional[CustomGraphModul return custom_backend_passes[device] if device in custom_backend_passes else None +def get_custom_backend_config_for_device(device: str) -> Optional[ConfigModule]: + return ( + custom_backend_codegen_configs[device] + if device in custom_backend_codegen_configs + else None + ) + + @functools.cache def init_backend_registration() -> None: from .cpp import CppScheduling diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 06467f06fc02..12584284631b 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -24,6 +24,7 @@ from ..._dynamo.utils import counters from .. import config, cpp_builder, cpu_vec_isa, ir, metrics +from ..debug import set_kernel_post_grad_provenance_tracing from ..loop_body import LoopBody from ..scheduler import ( BaseSchedulerNode, @@ -43,7 +44,6 @@ is_welford_reduction, parallel_num_threads, Placeholder, - set_kernel_post_grad_provenance_tracing, sympy_index_symbol, sympy_index_symbol_with_prefix, sympy_product, @@ -5191,7 +5191,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): ) kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()]) # below add provenance tracing info for cpu CppKernel types - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing(nodes, kernel_name) kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel" @@ -5282,8 +5282,11 @@ def codegen_group(self, name=None) -> str: arg_defs, _, _ = self.args.cpp_argdefs() arg_defs = ",\n".ljust(25).join(arg_defs) func_export_decl = get_export_declaration() + inline_attr = ( + "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else "" + ) code.writeline( - f'extern "C" {func_export_decl} void {kernel_decl_name}({arg_defs})' + f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})' ) # 3. Function body diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py index 64e11b00fcc0..80fd3014a643 100644 --- a/torch/_inductor/codegen/cpp_flex_attention_template.py +++ b/torch/_inductor/codegen/cpp_flex_attention_template.py @@ -814,7 +814,7 @@ def modification(self, subgraph_buffer, output_name, output_idx): from ..loop_body import LoopBody from ..utils import sympy_index_symbol_with_prefix, SymT from ..virtualized import V - from .cpp import CppKernelProxy, KernelGroup + from .cpp import CppKernelProxy, KernelGroup, ParallelDepth kernel_group = KernelGroup() kernel_input_args = { @@ -883,7 +883,15 @@ def fn(*args): var_sizes_list.append((var_sizes, ())) cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) - kernel_group.finalize_kernel(cpp_kernel_proxy, []) + + def max_parallel_depth(): + return ParallelDepth(parallel_depth=0, start_depth=0) + + # This loop is not parallelized since it is not the outermost loop. + with patch.object( + cpp_kernel_proxy.loop_nest, "max_parallel_depth", max_parallel_depth + ): + kernel_group.finalize_kernel(cpp_kernel_proxy, []) output_code = kernel_group.loops_code.getvalue() var_q_symbol, var_kv_symbol = self.block_vars diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index cbca6d9fe5d2..fa880d35366c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -21,7 +21,7 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.symbol import symbol_is_type, SymT -from .. import config, ir +from .. import config, cpp_builder, ir from ..utils import ( _align, aoti_model_name_from_config, @@ -119,7 +119,12 @@ def _generate_temporary_array_pointer( # e.g. const double** is possible, but not const double* const*. This means # that an array containing pointers must _already_ be properly const-qualified # by the c_type, and not add additional const-ness. - ptr_call = "data()" if force_mutable or c_type.endswith("*") else "cbegin()" + # MSVC does not support implicitly converting a const iterator to a const pointer. + ptr_call = ( + "data()" + if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl() + else "cbegin()" + ) return ( f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" ) @@ -630,10 +635,10 @@ def write_wrapper_decl(self): ), "Expect all constants to be Tensor" for idx, constants_key in enumerate(V.graph.constants.keys()): if V.graph.aot_mode: - # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Weights are stored in constants_ and owned by ConstantHandle there. # Don't call std::move here because it will cause constants_ to lose the ownership. self.prefix.writeline( - f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});""" + f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});""" ) else: # Append constants as inputs to the graph diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index 0b87a0f03795..b953927f52be 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -7,11 +7,16 @@ from ..ir import GraphPartitionSignature from ..virtualized import V +from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_gpu import CppWrapperGpu from .wrapper import PythonWrapperCodegen class CppWrapperMps(CppWrapperGpu): + """ + Generates cpp wrapper for running on MPS and calls metal kernels + """ + def __init__(self) -> None: super().__init__() self._used_kernel_names: OrderedSet[str] = OrderedSet() @@ -29,8 +34,15 @@ def _generate_kernel_call_helper( self, kernel_name: str, call_args: list[str], - arg_types: Optional[list[type]] = None, - **kwargs: dict[str, Any], + *, + device: Optional[torch.device] = None, + triton: bool = True, + arg_types: Optional[tuple[Any, ...]] = None, + raw_keys: Optional[tuple[Any, ...]] = None, + raw_args: Optional[tuple[Any, ...]] = None, + triton_meta: Optional[dict[str, Any]] = None, + graph_name: str = "", + original_fxnode_name: Optional[str] = None, ) -> None: """ Generates MPS kernel call code. It should look something like: @@ -46,17 +58,34 @@ def _generate_kernel_call_helper( }); ``` """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + assert device.type == "mps" + assert arg_types is not None new_args = [] for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): if isinstance(arg_type, torch.dtype): new_args.append( - f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});" ) elif arg_type in (int, sympy.core.symbol.Symbol): new_args.append( - f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});" ) else: raise NotImplementedError( @@ -81,28 +110,26 @@ def _generate_kernel_call_helper( "cpp", ) with debug_printer_manager: - self.writeline(self.wrap_kernel_call(kernel_name, new_args)) - - def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: - lib_name = name[: -len("_func")] - calling_args = " ".join(call_args) - - kernel_call_str = "" + self.write_mps_kernel_call(kernel_name, new_args) + def write_mps_kernel_call(self, name: str, call_args: list[str]) -> None: # Only add handle definition if the kernel is not already used + lib_name = name[: -len("_func")] if name not in self._used_kernel_names: self._used_kernel_names.add(name) - kernel_call_str += f""" - auto {name} = {lib_name}.getKernelFunction("generated_kernel"); - auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); - """ - kernel_call_str += f""" - {name}->runCommandBlock([&] {{ - {name}->startEncoding(); - {calling_args} - }}); - """ - return kernel_call_str + + self.writeline( + f'auto {name} = {lib_name}.getKernelFunction("generated_kernel");' + ) + self.writeline( + f"auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get());" + ) + + self.writeline(f"{name}->runCommandBlock([&] {{") + self.writeline(f" {name}->startEncoding();") + for call_arg in call_args: + self.writeline(f" {call_arg}") + self.writeline("});") @staticmethod def get_device_include_path(device: str) -> str: diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index f419ada67e1a..224f0d2a423d 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -631,16 +631,26 @@ def hash_key(self) -> str: """ Return kernel hash key that does not depend on swizzle. """ + swizzle_str: str = ( + str(self.info_kwargs.get("swizzle")) + if isinstance(self.info_kwargs, dict) + else "None" + ) return "-".join( [ self.category, self.bmreq.hash_key, - str(self.info_dict().get("swizzle")), + swizzle_str, ] ) def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: - """Information returned here is logged to the autotune log file when that is enabled.""" + """ + Information returned here is logged to the autotune log file when that is enabled. + + In general, we should avoid calling this function as it is expensive to compute, + and can add up very fast. + """ if self.info_kwargs is not None and "op" in self.info_kwargs: op: Any = self.info_kwargs["op"] return { diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 7ed67b0daa49..cc03ccbdda86 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -3,14 +3,15 @@ import hashlib import itertools from dataclasses import dataclass -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from typing_extensions import override from unittest.mock import patch import sympy import torch -from torch._inductor.utils import Placeholder +from torch._inductor import config +from torch._inductor.utils import clear_on_fresh_cache, Placeholder from torch._logging import getArtifactLogger from ...autotune_process import CUDABenchmarkRequest, TensorMeta @@ -38,8 +39,12 @@ class ArgInfo: ty: str +@clear_on_fresh_cache class CUDATemplate(KernelTemplate): index_counter = itertools.count() + # dict of cache key to (code, size_args) + code_cache: dict[str, tuple[str, tuple[int, ...]]] = {} + cache_clear = staticmethod(code_cache.clear) def __init__( self, @@ -49,15 +54,15 @@ def __init__( input_reorder: Optional[list[int]] = None, ) -> None: """ - - Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + Baseclass for CUDA C++ Templates, derived from KernelTemplate. + Not to be instantiated directly. Args: name (str): The name of the CUDATemplate object. input_nodes (List[IRNode]): A list of input IRNodes. layout (Layout): The layout of the output buffer / tensor. - input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. - + input_reorder (Optional[List[int]]): An optional list that specifies + the order of the input nodes. """ super().__init__(name) self.input_nodes = input_nodes @@ -65,34 +70,60 @@ def __init__( self.input_reorder = input_reorder self.layout = layout + @classmethod + @functools.lru_cache(None) + def _template_from_string(cls, source: str) -> Any: + return KernelTemplate._template_from_string(source) + @staticmethod def supports_epilogue_fusion(op: GemmOperation) -> bool: return False - def generate( # type: ignore[override] - self, - description, - **kwargs, - ) -> CUDATemplateCaller: + def make_key(self, name: str, input_key: str, layout_repr: str) -> str: """ - Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller - may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + Make a key for the code cache. The idea of the method is to cache + everything that matters but doesn't include runtime param values, i.e., + self.get_runtime_arg_values(). Args: - kwargs: Additional keyword arguments. - - Returns: - A CUDATemplateCaller object representing the generated CUDA template caller. + kwargs: Additional keyword arguments. Including op (GemmOperation). + """ + return hashlib.sha256( + str( + ( + input_key, + self.input_reorder, + # output layout, same as self.output_node.get_layout() + layout_repr, + self.get_runtime_arg_info(), + name, + ) + ).encode("utf-8") + ).hexdigest() + + def generate_code_and_args( + self, name: str, input_key: str, layout_repr: str, **kwargs + ) -> tuple[str, tuple[int, ...]]: + """ + Generate code and args with caching. We cache the code even if runtime + args are different. """ + key: Optional[str] = None + if config.cuda.enable_caching_codegen: + key = self.make_key(name=name, input_key=input_key, layout_repr=layout_repr) + + if key is not None and key in self.code_cache: + code, size_args = self.code_cache[key] + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + return code, extra_args + kernel_name = str(Placeholder.KERNEL_NAME) - with ( - patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), - CUDATemplateKernel( - kernel_name=kernel_name, - runtime_arg_info=self.get_runtime_arg_info(), - runtime_arg_values=self.get_runtime_arg_values(**kwargs), - ) as kernel, - ): + kernel = CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): code = self.render(kernel=kernel, **kwargs) _, call_args, _, _ = kernel.args.python_argdefs() autotuning_log.debug("Generated Code:\n%s", code) @@ -117,8 +148,45 @@ def generate( # type: ignore[override] ) V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + + if key is not None: + self.code_cache[key] = code, size_args + + # extra args has runtime params, which shouldn't be cached extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + return code, extra_args + + def generate( # type: ignore[override] + self, + name: str, + description: str, + input_key: str, + layout_repr: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. + This CUDATemplateCaller may be used to call and benchmark the generated CUDA kernel + in a standalone manner to enable Autotuning. + + Args: + description: op name followed by swizzle. + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. + """ + code, extra_args = self.generate_code_and_args( + name=name, + input_key=input_key, + layout_repr=layout_repr, + **kwargs, + ) + + # not caching since kernel name is needed below kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] kernel_name = f"cutlass_{kernel_hash}" code = code.replace(self.name, kernel_name) @@ -126,8 +194,8 @@ def generate( # type: ignore[override] # create the BenchmarkRequest bmreq = CUDABenchmarkRequest( kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, extra_args=extra_args, source_code=code, ) diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index e2251b42fc7e..7ca33ea779cc 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -45,7 +45,10 @@ def move_cutlass_compiled_cache() -> None: else: import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401 - if not os.path.exists(python_cutlass.CACHE_FILE): + # Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists + if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists( + python_cutlass.CACHE_FILE + ): return try: @@ -125,7 +128,7 @@ def path_join(path0, path1): if tmp_cutlass_full_path not in sys.path: def link_and_append(dst_link, src_path, parent_dir): - if os.path.exists(dst_link): + if os.path.lexists(dst_link): assert os.path.islink(dst_link), ( f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." ) diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index a38b846f7909..bdecc07d69a5 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -10,6 +10,7 @@ import torch import torch.utils._pytree as pytree +from torch._inductor.autotune_process import TensorMeta from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.scheduler import BaseSchedulerNode @@ -562,6 +563,16 @@ def _add_cutlass_gemm_choices( """ ops = self.gen_ops() + + # pre-computation + layout_repr: str = str(layout) + input_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.input_nodes) + ) + output_tensor_meta: Union[TensorMeta, list[TensorMeta]] = ( + TensorMeta.from_irnodes(self.output_node) + ) + with dynamo_timed("CUTLASSGemmTemplate.maybe_append_choice"): for name, op in ops: for ( @@ -569,7 +580,15 @@ def _add_cutlass_gemm_choices( ) in inductor_cuda_config.cutlass_max_profiling_swizzle_options: description = f"{name} swizzle={swizzle}" self.maybe_append_choice( - choices, description=description, op=op, swizzle=swizzle + choices, + op=op, + name=name, + description=description, + input_key=self.cache_key, + layout_repr=layout_repr, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + swizzle=swizzle, ) if len(ops) == 0: diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index f8176c191fd4..7060f857828e 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -929,7 +929,7 @@ def format_threads(threads: list[str], kwarg: str) -> str: wrapper.generate_kernel_call( name, args, - device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now + device=torch.device("mps"), triton=False, arg_types=arg_types, ) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 90bee26c0924..7ac967bbe0b0 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: from ..ir import IRNode +from ..debug import set_kernel_post_grad_provenance_tracing from ..optimize_indexing import indexing_dtype_strength_reduction from ..runtime.runtime_utils import green_text, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse @@ -51,7 +52,6 @@ IndentedBuffer, Placeholder, prefix_is_reduction, - set_kernel_post_grad_provenance_tracing, sympy_index_symbol, sympy_product, sympy_subs, @@ -1453,7 +1453,7 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): with V.set_kernel_handler(kernel): src_code = kernel.codegen_kernel() kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( node_schedule, # type: ignore[arg-type] kernel_name, @@ -1639,11 +1639,16 @@ def _codegen_single_template( partial_code.finalize_hook(subgraph_name, strict=False) with kernel.set_subgraph_body(""): - if isinstance(partial_code, str): - src_code = partial_code - else: + if not isinstance(partial_code, str): partial_code.finalize_hook("") - src_code = partial_code.code + + if isinstance(partial_code, str): + src_code = partial_code + else: + # Ensure all hooks are finalized before the kernel is defined. + # Note: some of these hooks may have been registered by a kernel subclass + src_code = partial_code.finalize_remaining() + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] if config.benchmark_kernel: @@ -1659,7 +1664,7 @@ def _codegen_single_template( kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel) - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( node_schedule, kernel.kernel_name ) @@ -1844,7 +1849,7 @@ def codegen_combo_kernel(self, combo_kernel_node): for src_code, kernel, _ in kernel_code_list: kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) # dump provenance node info for ComboKernelNode/ForeachKernel type - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing( combo_kernel_node.snodes, kernel_name ) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e7726263714f..683282fa9c5a 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -40,6 +40,7 @@ from .. import async_compile, config, ir from ..codecache import output_code_log +from ..debug import set_kernel_post_grad_provenance_tracing from ..ir import IRNode, ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties @@ -50,7 +51,6 @@ IndentedBuffer, is_codegen_graph_partition_subgraph, LineContext, - set_kernel_post_grad_provenance_tracing, sympy_product, sympy_str, sympy_subs, @@ -479,7 +479,7 @@ def codegen(self, code: IndentedBuffer) -> None: kernel_name = node.get_kernel_name() device = d.type if (d := node.get_device()) else V.graph.device_type # set provenance tracing kernel mapping for ExternKernel types - if config.trace.enabled: + if config.trace.provenance_tracking: set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) self.wrapper._generate_extern_kernel_out_helper( kernel_name, @@ -2541,7 +2541,9 @@ def _generate_kernel_call_helper( original_fxnode_name=None, ): device = device or V.graph.get_current_device_or_throw() - if not (triton or device.type != "cpu"): + if not ( + triton or device.type not in ("cpu", "mps") + ): # TODO: Fix me, MPS does not expose streams now self.writeline(self.wrap_kernel_call(kernel_name, call_args)) return diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index caaf43dba590..f93485333d30 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -4,7 +4,6 @@ import heapq import importlib -import itertools import logging import operator import sys @@ -149,9 +148,8 @@ def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool: return True if ( - hasattr(node, "python_kernel_name") - and node.python_kernel_name == "extern_kernels.mm" - ): + python_kernel_name := getattr(node, "python_kernel_name", None) + ) and "extern_kernels" in python_kernel_name: return True return False @@ -189,15 +187,23 @@ def _group_name(snode, with_bufs=False) -> str: def _reorder_communication_preserving_peak_memory_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node - - original_snodes_num = len(snodes) """ Internal testing helper that also returns debug info. Returns: - reordered snodes list - dict {snode: ReorderInfo} """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) # heuristic to avoid degenerating to quadratic time graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) @@ -208,7 +214,8 @@ def _reorder_communication_preserving_peak_memory_internal( snodes, name_to_freeable_input_buf, graph_outputs ) runtimes = {snode: estimate_op_runtime(snode) for snode in snodes} - snode_to_curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] # debug stats stats: dict[BaseSchedulerNode, ReorderInfo] = {} @@ -232,153 +239,151 @@ def accumulate_time(_snode): _temp_group_visit_leaves(snode, accumulate_time) return max(0, comm_time - compute_time) - MOVE_LIMIT = len(snodes) * 100 total_moves = 0 - # TODO - experiment with whether this limit is useful, setting `len(snodes)` disables it - PER_COLLECTIVE_PREFETCH_LIMIT = len(snodes) - if config.reorder_prefetch_limit is not None: - PER_COLLECTIVE_PREFETCH_LIMIT = config.reorder_prefetch_limit # Dicts to keep track of "next" and "previous" as double-linked structure during grouping - _prev: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} - _next: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]] = {} + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} for i, snode in enumerate(snodes): _prev[snode] = snodes[i - 1] if i > 0 else None _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None - - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i, gsnode in enumerate(gsnodes): - snode = gsnode.snodes[0] # type: ignore[attr-defined] - if contains_collective(snode): - reorder_info = stats[snode] = ReorderInfo() + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + _head = snodes[0] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = _head + while _next[curr] is not None: + if contains_collective(curr): + reorder_info = stats[curr] = ReorderInfo() reorder_info.initial_exposed = reorder_info.final_exposed = ( - exposed_communication_time(snode, snodes[i + 1 :]) + exposed_communication_time(curr, _group_nodes(_next[curr], None)) ) - if total_moves >= MOVE_LIMIT: - reorder_info.limiting_factor = "move limit" - continue - for j in range(i - 1, -1, -1): - prev_gsnode = gsnodes[j] - if len(prev_gsnode.snodes) == 0: - continue - - if j < max(0, i - PER_COLLECTIVE_PREFETCH_LIMIT): - reorder_info.limiting_factor = "prefetch limit" - break - if contains_collective(prev_gsnode): + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + if contains_collective(candidate): reorder_info.limiting_factor = "collective ordering" break - dep_names = OrderedSet([s.name for s in snode.unmet_dependencies]) - prev_outs = prev_gsnode.get_outputs() + group = GroupedSchedulerNode( + curr.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + + data_deps = {s.name: s for s in group.unmet_dependencies} + candidate_outs = candidate.get_outputs() data_dep = None - for o in prev_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break if data_dep is not None: - def is_groupable(prev_gsnode): + def is_groupable(candidate): # preserve ordering - if contains_collective(prev_gsnode): - return False - - if contains_gemm_like(prev_gsnode): - return False - return True - - if is_groupable(prev_gsnode): - new_snodes = prev_gsnode.snodes + gsnode.snodes - init_group_node(gsnode, gsnode.scheduler, new_snodes) - prev_gsnode.snodes = [] + if contains_collective(candidate): + return False, "contains_collective" + + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_head = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) reorder_info.grouped += 1 - reorder_info.grouped_info = gsnode.get_name() + reorder_info.grouped_info = _group_names(group_head, group_tail) + candidate = _prev[candidate] continue else: msg = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" prev_gsnode.outputs:{[o.get_name() for o in prev_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n non_group_reason:{grp_reason}" ) reorder_info.limiting_factor = msg break - if peak_memory - curr_memory[j] < curr_memory[j - 1] - curr_memory[j]: + delta_memory_candidate = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + + if group_peak_memory - delta_memory_candidate > peak_memory: reorder_info.limiting_factor = "peak memory" break - if reorder_info.final_exposed > runtimes[snode]: - reorder_info.limiting_factor = "sufficient overlapping" - break + reorder_info.moves += 1 total_moves += 1 - # swapping nodes j and j+1 affects curr memory at j only - # j_plus_one_alloc = curr_memory[j + 1] - curr_memory[j] - # j_alloc = curr_memory[j] - curr_memory[j - 1] - # curr_memory[j] = curr_memory[j] - j_alloc + j_plus_one_alloc - def swap_curr_memory_with_previous( - snode_j_plus_one, snode_j, snode_j_minus_one - ): - curr_memory_j_plus_one = snode_to_curr_memory[snode_j_plus_one] - curr_memory_j = snode_to_curr_memory[snode_j] - curr_memory_j_minus_one = ( - snode_to_curr_memory[snode_j_minus_one] - if snode_j_minus_one is not None - else 0 - ) - j_plus_one_alloc = curr_memory_j_plus_one - curr_memory_j - j_alloc = curr_memory_j - curr_memory_j_minus_one - snode_to_curr_memory[snode_j] = ( - curr_memory_j - j_alloc + j_plus_one_alloc - ) - - # Recompuing curr_mem for swapping grouped nodes j (group A) and j + 1 (group B) - # swap([A0, A1, A2], [B0, B1]) --> [B0, B1], [A0, A1, A2] - # decomposing to: - # swap(A2, B0) -> A0, A1, B0, A2, B1 - # swap(A2, B1) -> A0, A1, B0, B1, A2 - # swap(A1, B0) -> A0, B0, A1, B1, A2 - # swap(A1, B1) -> A0, B0, B1, A1, A2 - # swap(A0, B0) -> B0, A0, B1, A1, A2 - # swap(A0, B1) -> B0, B1, A0, A1, A2 - for _j in range(len(gsnodes[j].snodes) - 1, -1, -1): # group A - snode_j = gsnodes[j].snodes[_j] - for _i, snode_i in enumerate(gsnode.snodes): # group B - swap_curr_memory_with_previous( - snode_j_plus_one=snode_i, - snode_j=snode_j, - snode_j_minus_one=_prev[snode_j], - ) + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # swap (candidate, group_head...group_tail) + # Before: + # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next + # After: + # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next + # 0 + candidate_prev = _prev[candidate] + if candidate_prev: + _next[candidate_prev] = group_head + _prev[group_head] = candidate_prev + + # 2 + group_tail_next = _next[group_tail] + if group_tail_next: + _prev[group_tail_next] = candidate + _next[candidate] = group_tail_next + + # 1 + _prev[candidate] = group_tail + _next[group_tail] = candidate + + if _head == candidate: + _head = group_head - # Update _next and _prev for swap [snode_j, snode_i] -> [snode_i, snode_j] - first = snode_j - second = snode_i - first_prev = _prev[first] - second_next = _next[second] - if first_prev: - _next[first_prev] = second - _prev[second] = first_prev - - if second_next: - _prev[second_next] = first - _next[first] = second_next - - _next[second] = first - _prev[first] = second - - tmp = gsnodes[j] - gsnodes[j] = gsnodes[j + 1] - gsnodes[j + 1] = tmp reorder_info.final_exposed = exposed_communication_time( - snode, - itertools.chain( - gsnode.snodes[1:], *[n.snodes for n in gsnodes[j + 1 :]] - ), + curr, _group_nodes(_next[curr], None) ) + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[group_head]] # type: ignore[index] + for n in _group_nodes(group_head, candidate): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + candidate = _prev[group_head] + curr = _next[curr] # type: ignore[assignment] node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} @@ -426,17 +431,13 @@ def swap_curr_memory_with_previous( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - grouping_logs: list[str] = [] - flatten_gsnodes: list[BaseSchedulerNode] = [] - for i, gsnode in enumerate(gsnodes): - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_gsnodes.extend(gsnode.snodes) - else: - flatten_gsnodes.append(gsnode) - - grouping_log_str = "\n".join(grouping_logs) - reorder_log_str += "\n" - reorder_log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + reorder_log_str += f"\n peak_memory_before:{peak_memory}" + reorder_log_str += f"\n peak_memory_after:{new_peak_memory}" overlap_log.info(reorder_log_str) trace_structured( @@ -448,8 +449,7 @@ def swap_curr_memory_with_previous( payload_fn=lambda: reorder_log_str, ) - assert len(flatten_gsnodes) == original_snodes_num - return flatten_gsnodes, stats + return new_snodes, stats def _schedule_for_comm( @@ -623,7 +623,9 @@ def decide_global_ordering_of_comms( # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) for buf in comm_nodes[i - 1].get_buffer_names(): - comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) + comm_nodes[i].add_fake_dep( + WeakDep(buf, mutating_buf=mutating_buf, is_fake=True) + ) return nodes @@ -640,66 +642,166 @@ class SinkWaitInfo: def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode, init_group_node + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( + snodes, graph_inputs + ) + peak_memory, curr_memory = estimate_peak_memory( + snodes, name_to_freeable_input_buf, graph_outputs + ) - n = len(snodes) stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - gsnodes: list[GroupedSchedulerNode] = [ - GroupedSchedulerNode(snode.scheduler, [snode], temp_grouping=True) - for snode in snodes - ] - for i in range(n - 1, -1, -1): - gsnode = gsnodes[i] - if contains_wait(gsnode): - info = stats[gsnode.snodes[0]] = SinkWaitInfo() - for j in range(i + 1, n): - wait_gsnode = gsnodes[j - 1] - wait_outs = wait_gsnode.get_outputs() - next_gsnode = gsnodes[j] - dep_names = OrderedSet([s.name for s in next_gsnode.unmet_dependencies]) + _prev: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _next: dict[Optional[BaseSchedulerNode], Optional[BaseSchedulerNode]] = {} + _head = snodes[0] + for i, snode in enumerate(snodes): + _prev[snode] = snodes[i - 1] if i > 0 else None + _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None + _curr_memory = dict(zip(snodes, curr_memory)) + _curr_memory[None] = 0 # type: ignore[index] + + def _group_nodes(head, tail): + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = _next[n] + return ret + + def _group_names(head, tail): + ret = "" + for n in _group_nodes(head, tail): + if ret: + ret += "~" + ret += n.get_name() + return ret + + curr = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + while _prev[curr] is not None: + if contains_wait(curr) and curr not in processed_waits: + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_peak_memory = _curr_memory[curr] + while candidate is not None: + group = GroupedSchedulerNode( + wait_snode.scheduler, + _group_nodes(group_head, group_tail), + temp_grouping=True, + ) + group_outs = group.get_outputs() + + data_deps = {s.name: s for s in candidate.unmet_dependencies} data_dep = None - for o in wait_outs: - if o.get_name() in dep_names: - data_dep = o.get_name() + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + if isinstance(d, WeakDep) and d.is_fake: + continue + data_dep = d break # 1. If we have data_dep - we can not swap => trying to group # 2. If swap candidate and current node both contain collectives => trying to group if data_dep is not None or ( both_contain_comms := ( - contains_collective(wait_gsnode) - and contains_collective(next_gsnode) + contains_collective(group) and contains_collective(candidate) ) ): def is_groupable(snode): - return not contains_gemm_like(snode) - - if is_groupable(next_gsnode): - new_snodes = wait_gsnode.snodes + next_gsnode.snodes - init_group_node(next_gsnode, gsnode.scheduler, new_snodes) - wait_gsnode.snodes = [] + # We do not want to group with collectives to not reorder them forward. + if contains_collective(snode): + return ( + False, + f"candidate contains collective {snode.get_name()}", + ) + if contains_gemm_like(snode): + return ( + False, + f"candidate contains gemm_like {snode.get_name()}", + ) + return True, None + + is_grp, grp_reason = is_groupable(candidate) + if is_grp: + group_tail = candidate + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate] + ) info.grouped += 1 - info.grouped_info = _group_name(next_gsnode) + info.grouped_info = _group_names(group_head, group_tail) + candidate = _next[candidate] continue elif (data_dep is None) and both_contain_comms: info.limiting_factor = ( - f"collective ordering {_group_name(wait_gsnode)}" - f" with candidate:{_group_name(next_gsnode)}" + f"collective ordering {_group_names(group_head, group_tail)}" + f" with candidate:{candidate.get_name()}" ) + break else: info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{dep_names})" - f" candidate:{_group_name(next_gsnode)} dep on {_group_name(wait_gsnode)}" - f" outs:{[o.get_name() for o in wait_outs]}" + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(group_head, group_tail)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{grp_reason}" ) break + candidate_delta_memory = ( + _curr_memory[candidate] - _curr_memory[_prev[candidate]] # type: ignore[index] + ) + if group_peak_memory + candidate_delta_memory > peak_memory: + info.limiting_factor = "peak_memory" + break + info.moves += 1 - info.moves_info += f"+{_group_name(next_gsnode)}" + info.moves_info += f"+{candidate.get_name()}" + + # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next + mem_deltas = {} + for n in [candidate, *_group_nodes(group_head, group_tail)]: + mem_deltas[n] = _curr_memory[n] - _curr_memory[_prev[n]] # type: ignore[index] + # 0: + group_head_prev = _prev[group_head] + if group_head_prev: + _next[group_head_prev] = candidate + _prev[candidate] = group_head_prev + + # 2: + candidate_next = _next[candidate] + if candidate_next: + _prev[candidate_next] = group_tail + _next[group_tail] = candidate_next + + # 1: + _prev[group_head] = candidate + _next[candidate] = group_head + if group_head == _head: + _head = candidate + + # Recompute curr_memory + _prev_curr_memory = _curr_memory[_prev[candidate]] # type: ignore[index] + for n in _group_nodes(candidate, group_tail): + _curr_memory[n] = _prev_curr_memory = ( + _prev_curr_memory + mem_deltas[n] + ) + + candidate = _next[group_tail] + curr = _prev[curr] # type: ignore[assignment] - # Swapping snodes j and j - 1 - tmp = gsnodes[j - 1] - gsnodes[j - 1] = gsnodes[j] - gsnodes[j] = tmp headers = [ "Wait node", "grouped", @@ -732,16 +834,13 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - grouping_logs = [] - flatten_snodes = [] - for i, gsnode in enumerate(gsnodes): - grouping_logs.append(f"gsnode[{i}]:{_group_name(gsnode, with_bufs=True)}") - if isinstance(gsnode, GroupedSchedulerNode) and gsnode.temp_grouping: - flatten_snodes.extend(gsnode.snodes) - else: - flatten_snodes.append(gsnode) - grouping_log_str = "\n".join(grouping_logs) - log_str += grouping_log_str + new_snodes = _group_nodes(_head, None) + assert len(new_snodes) == original_snodes_num + new_peak_memory, curr_memory = estimate_peak_memory( + new_snodes, name_to_freeable_input_buf, graph_outputs + ) + log_str += f"\n peak_memory_before:{peak_memory}" + log_str += f"\n peak_memory_after:{new_peak_memory}" trace_structured( "artifact", metadata_fn=lambda: { @@ -750,8 +849,7 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - assert len(flatten_snodes) == n - return flatten_snodes, stats + return new_snodes, stats def sink_waits_iterative( @@ -777,7 +875,9 @@ def node_summary(snode): if len(snodes) == 1: detail = "" if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): - detail = f" ({snode.node.python_kernel_name})" + outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" + ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ @@ -1352,7 +1452,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(ag_group_node.get_buffer_names())) for o in prev_ag_wait.get_outputs(): ag_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_ag_wait = wait_group_node @@ -1364,7 +1464,7 @@ def _create_group_node(snodes_to_group): mutating_buf = next(iter(rs_group_node.get_buffer_names())) for o in prev_rs_wait.get_outputs(): rs_group_node.add_fake_dep( - WeakDep(o.get_name(), mutating_buf=mutating_buf) + WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True) ) prev_rs_wait = wait_group_node diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index bfdb9a54e56f..8e712a28a3b0 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -53,6 +53,7 @@ ) from torch._functorch.aot_autograd import ( aot_export_module, + GraphOutputName, make_boxed_func, SerializableAOTDispatchCompiler, ) @@ -429,7 +430,7 @@ def _unlift_graph( from torch.export._unlift import _unlift - outputs = list(gm.graph.nodes)[-1].args[0] + outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] mutated_outputs = [] buffer_mutations = graph_signature.buffers_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate @@ -438,10 +439,11 @@ def _unlift_graph( value: Optional[Union[FQN, GraphInputName]] = None if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): - if out.name in buffer_mutations: - value = buffer_mutations[out.name] - elif out.name in user_input_mutations: - value = user_input_mutations[out.name] + name = GraphOutputName(out.name) + if name in buffer_mutations: + value = buffer_mutations[name] + elif name in user_input_mutations: + value = user_input_mutations[name] mutated_outputs.append(value) @@ -451,8 +453,6 @@ def _unlift_graph( mutated_outputs, pytree.LeafSpec(), None, - state_dict, - {}, ) return unlifted_gm @@ -909,10 +909,37 @@ def _compile_fx_inner( else: log.debug("Failed to generate FX cache key") + if torch._functorch.config.bundled_autograd_cache: + assert mb_compiled_graph is None + assert cache_info is None + # When using bundled autograd cache, we still want + # to use the TritonBundler, but we don't want to save + # the results here. The results will get saved directly + # to AOTAutogradCache. + TritonBundler.begin_compile() + try: + mb_compiled_graph = fx_codegen_and_compile( + gm, example_inputs, inputs_to_check, **graph_kwargs + ) + assert mb_compiled_graph is not None + ( + triton_bundle, + triton_bundler_meta, + ) = TritonBundler.collect() + mb_compiled_graph.set_triton_bundle(triton_bundle) + except (ShortenTraceback, SkipFrame): + raise + except Exception as e: + raise InductorError(e, currentframe()).with_traceback( + e.__traceback__ + ) from None + finally: + TritonBundler.end_compile() + # CACHE BYPASS: Compile the graph, don't save it to the cache # (this can happen either because cache was disabled, or we # determined the input is uncacheable) - if cache_info is None or cache_info["cache_state"] == "bypass": + elif cache_info is None or cache_info["cache_state"] == "bypass": assert mb_compiled_graph is None log.debug( "FX cache bypass reason: %s", @@ -1028,30 +1055,19 @@ def _compile_fx_inner( log.debug("FX codegen and compilation took %.3fs", time.time() - start) - # Dump provenance artifacts for debugging trace - provenance_info = V.debug.log_inductor_triton_kernel_to_post_grad_node_info() - # provenance_info might be None if config.trace.enabled is not set - if provenance_info: - ( - debug_info, - node_mappings, - ) = provenance_info - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_generated_kernel_to_post_grad_nodes", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(debug_info), - ) - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "inductor_provenance_tracking_node_mappings", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(node_mappings), - ) + if config.trace.provenance_tracking: + # Dump provenance artifacts for debugging trace + provenance_info = torch._inductor.debug.dump_inductor_provenance_info() + # provenance_info might be None if trace.provenance_tracking is not set + if provenance_info: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_provenance_tracking_node_mappings", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(provenance_info), + ) # This message is for printing overview information of inductor mm counts, shapes,etc after lowering if log.isEnabledFor(logging.INFO): @@ -1294,7 +1310,7 @@ def codegen_and_compile( }, payload_fn=lambda: inductor_post_grad_graph_str, ) - if config.trace.enabled: + if config.trace.provenance_tracking: provenance_tracking_json = ( torch.fx.traceback.get_graph_provenance_json(gm.graph) ) @@ -1306,8 +1322,13 @@ def codegen_and_compile( }, payload_fn=lambda: json.dumps(provenance_tracking_json), ) + from torch._inductor.debug import create_mapping_pre_post_grad_nodes + torch._inductor.debug._inductor_post_to_pre_grad_nodes = ( - provenance_tracking_json + create_mapping_pre_post_grad_nodes( + torch._inductor.debug._pre_grad_graph_id, + provenance_tracking_json, + ) ) metrics_context = get_metrics_context() @@ -2147,7 +2168,8 @@ def compile_fx( with ( _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), enable_python_dispatcher(), - torch.fx.traceback.preserve_node_meta(config.trace.enabled), + torch.fx.traceback.preserve_node_meta(config.trace.provenance_tracking), + torch._inductor.debug.reset_provenance_globals(), ): # Pre-grad passes cannot be run if we weren't given a GraphModule. # Dynamo will always produce a GraphModule, but this handles cases @@ -2180,6 +2202,13 @@ def compile_fx( ) torch._inductor.debug._pre_grad_graph_id = id(model_.graph) + if config.trace.provenance_tracking: + for node in model_.graph.nodes: + if node.stack_trace: + torch._inductor.debug._inductor_pre_grad_node_stack_trace[ + node.name + ] = node.stack_trace + model_ = _recursive_pre_grad_passes(model_, example_inputs_) trace_structured( "artifact", diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 293c1b233343..ae2ee6a574c7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -425,9 +425,6 @@ def prologue_fusion_enabled() -> bool: # enable slow autotuning passes to select gemm algorithms max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1" -# disable decomposek autotune choice for gemm -disable_decompose_k = os.environ.get("TORCHINDUCTOR_DISABLE_DECOMPOSE_K") == "1" - # Modifies the number of autotuning choices displayed, set to None for all autotune_num_choices_displayed: Optional[int] = 10 @@ -574,6 +571,9 @@ def use_autoheuristic(name: str) -> bool: # Threshold to prevent excessive accumulation of ops in one buffer during lowering realize_acc_reads_threshold = 8 +realize_acc_reads_size_threshold: Optional[int] = ( + None # TODO(xuanzh): harden this to make it non optional +) # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False @@ -995,7 +995,7 @@ def decide_compile_threads() -> int: annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1" # Enable caching codegen of triton templates. -enable_caching_generated_triton_templates: bool = False +enable_caching_generated_triton_templates: bool = True # Lookup table for overriding autotune configs based on hash of Triton source code autotune_lookup_table: dict[str, dict[str, Any]] = {} @@ -1003,6 +1003,11 @@ def decide_compile_threads() -> int: # config specific to codegen/cpp.py class cpp: + """ + Settings for cpp backend. + This class provides a centralized location for managing cpp backend settings. + """ + # set to torch.get_num_threads() threads = -1 @@ -1118,6 +1123,10 @@ class cpp: # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr] use_small_dequant_buffer = False + force_inline_kernel = ( + os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1" + ) + class triton: """ @@ -1333,6 +1342,17 @@ class triton: # Note: it may also need to be used with config.compile_threads = 1 disallow_failing_autotune_kernels_TESTING_ONLY = False + # specify number of splits to autotune on for decompose_k. 0 disables decompose_k + num_decompose_k_splits = int( + os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10") + ) + + # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables + # it as an autotuning choice for all matmuls + decompose_k_threshold = int( + os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") + ) + class aot_inductor: """ @@ -1497,11 +1517,11 @@ class cuda: # Path to the CUTLASS repo root directory. # The default path only works under PyTorch local development environment. - cutlass_dir = os.environ.get( - "TORCHINDUCTOR_CUTLASS_DIR", - os.path.abspath( - os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/") - ), + cutlass_dir = os.path.realpath( + os.environ.get( + "TORCHINDUCTOR_CUTLASS_DIR", + os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"), + ) ) # Configures the maximum number of CUTLASS configs to profile in max_autotune. @@ -1598,6 +1618,9 @@ class cuda: # Use this to overwrite and handle cache pollution binary_remote_cache_force_write: bool = False + # Enable caching codegen of cuda templates. + enable_caching_codegen: bool = True + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. @@ -1763,8 +1786,11 @@ class trace: log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1" - # Save mapping info from inductor generated triton kernel to post_grad fx nodes - log_inductor_triton_kernel_to_post_grad_node_info: bool = True + # Save mapping info from inductor generated triton kernel to post_grad fx nodes to pre_grad fx nodes + provenance_tracking = ( + os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1" + or os.environ.get("INDUCTOR_PROVENANCE", "0") == "1" + ) _save_config_ignore: list[str] = [ diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 9d0c7c76dfea..47820d3d7725 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -1074,6 +1074,19 @@ def _get_openmp_args( return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args +def _get_libstdcxx_args() -> tuple[list[str], list[str]]: + """ + For fbcode cpu case, we should link stdc++ instead assuming the binary where dlopen is executed is built with dynamic stdc++. + """ + lib_dir_paths: list[str] = [] + libs: list[str] = [] + if config.is_fbcode(): + lib_dir_paths = [sysconfig.get_config_var("LIBDIR")] + libs.append("stdc++") + + return lib_dir_paths, libs + + def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]: macros = [] if use_mmap_weights: @@ -1089,6 +1102,15 @@ def get_cpp_torch_options( use_relative_path: bool, use_mmap_weights: bool, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of torch related build options. + 1. Torch include_directories, libraries, libraries_directories. + 2. Python include_directories, libraries, libraries_directories. + 3. OpenMP related. + 4. Torch MACROs. + 5. MISC + 6. Return the build args + """ definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1269,6 +1291,13 @@ def get_cpp_torch_device_options( aot_mode: bool = False, compile_only: bool = False, ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]: + """ + This function is used to get the build args of device related build options. + 1. Device include_directories, libraries, libraries_directories. + 2. Device MACROs. + 3. MISC + 4. Return the build args + """ definitions: list[str] = [] include_dirs: list[str] = [] cflags: list[str] = [] @@ -1288,6 +1317,8 @@ def get_cpp_torch_device_options( include_dirs = cpp_extension.include_paths(device_type) libraries_dirs = cpp_extension.library_paths(device_type) + if not config.is_fbcode(): + libraries += ["c10"] if device_type == "cuda": definitions.append(" USE_ROCM" if torch.version.hip else " USE_CUDA") @@ -1327,6 +1358,14 @@ def get_cpp_torch_device_options( # Only add link args, when compile_only is false. passthrough_args = ["-Wl,-Bstatic -lcudart_static -Wl,-Bdynamic"] + if device_type == "cpu": + ( + stdcxx_lib_dir_paths, + stdcxx_libs, + ) = _get_libstdcxx_args() + libraries_dirs += stdcxx_lib_dir_paths + libraries += stdcxx_libs + if config.aot_inductor.custom_op_libs: libraries += config.aot_inductor.custom_op_libs diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index bdc201803fb6..3b3dea909cd2 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -90,6 +90,7 @@ from torch._guards import CompileId from torch._inductor.utils import InputType + from torch.cuda import _POOL_HANDLE from torch.types import _bool StorageWeakRefPointer = int @@ -817,7 +818,7 @@ def __init__( id: GraphID, parent: Optional[CUDAGraphNode], inputs: list[InputType], - cuda_graphs_pool: tuple[int, int], + cuda_graphs_pool: _POOL_HANDLE, device_index: int, stack_traces: Optional[StackTraces], stream: torch.cuda.Stream, @@ -1228,6 +1229,7 @@ def all_outputs_are_dead(self) -> bool: def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType: "Record the model" + assert self.graph is not None def static_input_iter() -> Generator[torch.Tensor, None, None]: for i in self.wrapped_function.static_input_idxs: @@ -1310,13 +1312,11 @@ def _add_first_outputs( self.output_storage_alias.append(UnaliasedStorage) continue - ( - torch._check( - o.is_cuda or o.untyped_storage().data_ptr() == 0, - lambda: ( - "Expected all cuda outputs in cuda graph recording. Non cuda output " - f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" - ), + torch._check( + o.is_cuda or o.untyped_storage().data_ptr() == 0, + lambda: ( + "Expected all cuda outputs in cuda graph recording. Non cuda output " + f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" ), ) diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index d3bc89a3d412..23b26765df2b 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -13,7 +13,7 @@ import pstats import shutil import traceback -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from typing import Any, Callable, IO, Optional, Union from unittest.mock import patch @@ -31,6 +31,7 @@ from torch.utils._pytree import tree_map from . import config, ir # noqa: F811, this is needed +from .ir import ExternKernelOut from .scheduler import ( BaseSchedulerNode, FusedSchedulerNode, @@ -313,15 +314,45 @@ def enable_aot_logging() -> Iterator[None]: # They are not stored in DebugContext because they are not set in # _inductor_triton_kernel_to_post_grad_node_info's Debug Context _inductor_post_to_pre_grad_nodes: dict[str, Any] = {} +_inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {} _pre_grad_graph_id: Optional[int] = None +_inductor_pre_grad_node_stack_trace: dict[str, str] = {} + + +@contextlib.contextmanager +def reset_provenance_globals() -> Iterator[None]: + """Context manager that resets provenance tracking globals upon entering + and restores their original values when exiting.""" + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + + # Store original values + original_pre_grad_graph_id = _pre_grad_graph_id + original_post_to_pre_grad_nodes = _inductor_post_to_pre_grad_nodes.copy() + original_triton_kernel_to_post_grad_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.copy() + ) + + # Reset to default values + _pre_grad_graph_id = -1 + _inductor_post_to_pre_grad_nodes = {} + _inductor_triton_kernel_to_post_grad_node_info = {} + + try: + yield + finally: + # Restore original values + _pre_grad_graph_id = original_pre_grad_graph_id + _inductor_post_to_pre_grad_nodes = original_post_to_pre_grad_nodes + _inductor_triton_kernel_to_post_grad_node_info = ( + original_triton_kernel_to_post_grad_node_info + ) class DebugContext: _counter = itertools.count() - # Used for provenance tracking - _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {} - @staticmethod def create_debug_dir(folder_name: str) -> Optional[str]: debug_dir = config.trace.debug_dir or get_debug_dir() @@ -557,25 +588,6 @@ def draw_orig_fx_graph( def output_code(self, filename: str, extension: str = "py") -> None: shutil.copy(filename, self.filename(f"output_code.{extension}")) - def log_inductor_triton_kernel_to_post_grad_node_info( - self, filename: str = "inductor_generated_kernel_to_post_grad_nodes.json" - ) -> tuple[dict[str, list[str]], dict[str, Any]]: - debug_info = {} - with self.fopen(filename, "w") as fd: - log.info("Writing provenance tracing debugging info to %s", fd.name) - debug_info = DebugContext._inductor_triton_kernel_to_post_grad_node_info - json.dump(debug_info, fd) - node_mapping = {} - if _pre_grad_graph_id: - with self.fopen( - "inductor_provenance_tracking_node_mappings.json", "w" - ) as fd: - node_mapping = create_node_mapping( - _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info - ) - json.dump(node_mapping, fd) - return debug_info, node_mapping - def log_autotuning_results( self, name: str, @@ -690,23 +702,18 @@ class TensorMetadataHolder: save_args_cnt = itertools.count() -def create_node_mapping( - pre_grad_graph_id: int, +def create_mapping_pre_post_grad_nodes( + pre_grad_graph_id: Optional[int], post_to_pre_grad_nodes_json: dict[str, Any], - triton_kernel_to_post_grad_json: dict[str, Any], ) -> dict[str, dict[str, Any]]: - """Create bidirectional mappings between: - - - pre_grad graph nodes and post_grad graph code nodes, and vice versa - - triton kernel name and post_grad graph code nodes, and vice versa """ - + Create bidirectional mappings between pre_grad graph nodes + and post_grad graph code nodes, and vice versa. + """ # return a dummy dict if there's any error empty_return: dict[str, dict[str, Any]] = { "preToPost": {}, "postToPre": {}, - "cppCodeToPost": {}, - "postToCppCode": {}, } log.info("Creating node mappings for provenance tracking") @@ -715,12 +722,6 @@ def create_node_mapping( log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict") return empty_return - if not isinstance(triton_kernel_to_post_grad_json, dict): - log.error( - "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" - ) - return empty_return - if not isinstance(pre_grad_graph_id, int): log.error("Provenance tacking error: pre_grad_graph_id is not an int") return empty_return @@ -728,17 +729,7 @@ def create_node_mapping( pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet) post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet) - post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) - try: - for outer_key, node_array in triton_kernel_to_post_grad_json.items(): - if not isinstance(node_array, list): - log.error( - "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" - ) - return empty_return - for curr_node in node_array: - post_to_cpp_code[curr_node].add(outer_key) def check_format(node: dict[str, Any]) -> bool: if not isinstance(node, dict): @@ -788,10 +779,61 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # convert to list because set is not JSON serializable convert_sets_to_lists(pre_to_post) convert_sets_to_lists(post_to_pre) - convert_sets_to_lists(post_to_cpp_code) return { "preToPost": pre_to_post, "postToPre": post_to_pre, + } + except Exception as e: + # Since this is just logging code, it should never interfere with regular + # program execution, so we use this try-except to guard against any error + log.error("Unexpected error in create_node_mapping: %s", e) + log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) + log.error("pre_grad_graph_id: %s", pre_grad_graph_id) + log.error(traceback.format_exc()) + return empty_return + + +def create_node_mapping_kernel_to_post_grad( + triton_kernel_to_post_grad_json: dict[str, Any], +) -> dict[str, dict[str, Any]]: + """Create bidirectional mappings between triton kernel name and post_grad + graph code nodes, and vice versa. + """ + + # return a dummy dict if there's any error + empty_return: dict[str, dict[str, Any]] = { + "cppCodeToPost": {}, + "postToCppCode": {}, + } + + log.info("Creating node mappings for provenance tracking") + + if not isinstance(triton_kernel_to_post_grad_json, dict): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict" + ) + return empty_return + + post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet) + + try: + for outer_key, node_array in triton_kernel_to_post_grad_json.items(): + if not isinstance(node_array, list): + log.error( + "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list" + ) + return empty_return + for curr_node in node_array: + post_to_cpp_code[curr_node].add(outer_key) + + def convert_sets_to_lists(d: dict[str, Any]) -> None: + for key in d: + d[key] = list(d[key]) + d = dict(d) + + # convert to list because set is not JSON serializable + convert_sets_to_lists(post_to_cpp_code) + return { "cppCodeToPost": triton_kernel_to_post_grad_json, "postToCppCode": post_to_cpp_code, } @@ -799,15 +841,83 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None: # Since this is just logging code, it should never interfere with regular # program execution, so we use this try-except to guard against any error log.error("Unexpected error in create_node_mapping: %s", e) - log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json) log.error( "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json ) - log.error("pre_grad_graph_id: %s", pre_grad_graph_id) log.error(traceback.format_exc()) return empty_return +def dump_inductor_provenance_info( + filename: str = "inductor_generated_kernel_to_post_grad_nodes.json", +) -> dict[str, Any]: + global _pre_grad_graph_id + global _inductor_post_to_pre_grad_nodes + global _inductor_triton_kernel_to_post_grad_node_info + if config.trace.enabled: + with V.debug.fopen(filename, "w") as fd: + log.info("Writing provenance tracing debugging info to %s", fd.name) + json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd) + node_mapping = {} + if _pre_grad_graph_id: + node_mapping_kernel = create_node_mapping_kernel_to_post_grad( + _inductor_triton_kernel_to_post_grad_node_info + ) + node_mapping = { + **_inductor_post_to_pre_grad_nodes, + **node_mapping_kernel, + } + if config.trace.enabled: + with V.debug.fopen( + "inductor_provenance_tracking_node_mappings.json", "w" + ) as fd: + json.dump(node_mapping, fd) + return node_mapping + + +def set_kernel_post_grad_provenance_tracing( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], + kernel_name: str, + is_extern: bool = False, +) -> None: + from .codegen.simd_kernel_features import DisableReduction, EnableReduction + + global _inductor_triton_kernel_to_post_grad_node_info + if is_extern: + assert isinstance(node_schedule, ExternKernelOut) + curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel. + # "origin_node" is more precise and says that the contents of this node corresponds + # EXACTLY to the output of a particular FX node, but it's not always available + if node_schedule.origin_node: + origin_node_name = node_schedule.origin_node.name + if origin_node_name not in curr_node_info: + curr_node_info.append(origin_node_name) + else: + curr_node_info.extend( + origin.name + for origin in node_schedule.origins + if origin.name not in curr_node_info + ) + else: + assert isinstance(node_schedule, list) + for snode in node_schedule: + if snode not in (EnableReduction, DisableReduction): + if snode.node is not None: + curr_node_info = ( + _inductor_triton_kernel_to_post_grad_node_info.setdefault( + kernel_name, [] + ) + ) + curr_node_info.extend( + origin.name + for origin in snode.node.origins + if origin.name not in curr_node_info + ) + + def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None: """ This function is used to save arguments for a compile_fx_inner function call diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 08c3abc9f23f..b81e6edbb54e 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -701,7 +701,7 @@ def randint( def linear_dynamic_fp16_unpacked_weight( input: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight) return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight( diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 9de52061c648..8a374f5bab35 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -342,6 +342,12 @@ class WeakDep(Dep): name: str # Buffer that is doing the mutation mutating_buf: str + # WeakDep's are also used to add dependencies to prevent some specific reordering, + # E.g. collectives global ordering. + # But if other pass guarantees proper ordering by its logic, + # This additional "fake" deps will be holding optimizations. + # This flag is used to identify those additional deps. + is_fake: bool = False @property def index(self) -> sympy.Expr: @@ -352,7 +358,7 @@ def get_numel(self) -> sympy.Expr: def rename(self, renames: dict[str, str]) -> "WeakDep": if self.name in renames: - return WeakDep(renames[self.name], self.mutating_buf) + return WeakDep(renames[self.name], self.mutating_buf, self.is_fake) return self def numbytes_hint(self) -> int: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 8f5bb5ffd324..1794ce3a2a29 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -1,6 +1,7 @@ import logging import math import operator +from collections import defaultdict from typing import Any, Callable, Optional, Union import torch @@ -77,13 +78,9 @@ def bucket_all_gather_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all all_gather nodes and groups them into buckets based on size limit `all_gather_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of all_gather nodes. """ - node_list = gm.graph.nodes - # Prerequisite: Check if there is any all_gather node found_all_gather = False for node in node_list: @@ -92,48 +89,53 @@ def bucket_all_gather_by_mb( break if not found_all_gather: return [] - - ag_nodes: list[torch.fx.Node] = [] - + group_name_ag_nodes: dict[tuple[str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all all_gather nodes for node in node_list: if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): if (filter_wait_node is None) or filter_wait_node(node): ag_node = node.args[0] - ag_nodes.append(ag_node) - + _, group_size, group_name = ag_node.args + dtype = ag_node.meta["val"].dtype + assert isinstance(group_name, str) + group_name_ag_nodes[(group_name, dtype)].append(ag_node) # Step 2: Put all_gather nodes into buckets ag_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - all_gather_bucket_size_bytes = int( - all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for ag_node in ag_nodes: - assert is_all_gather_into_tensor(ag_node) - assert "val" in ag_node.meta - ag_output_size_bytes = ( - ag_node.meta["val"].numel() - * torch.finfo(ag_node.meta["val"].dtype).bits - // 8 + for (group_name, dtype), ag_nodes in group_name_ag_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + all_gather_bucket_size_bytes = int( + all_gather_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + ag_output_size_bytes > all_gather_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket + for ag_node in ag_nodes: + assert is_all_gather_into_tensor(ag_node) + if ag_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + assert "val" in ag_node.meta + ag_n_val = ag_node.meta["val"] + ag_output_size_bytes = ag_n_val.numel() * ag_n_val.element_size() + if ( + cur_bucket_size_bytes + ag_output_size_bytes + > all_gather_bucket_size_bytes + and cur_bucket + ): + # Current bucket is full, create new bucket + if len(cur_bucket) > 1: + ag_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + cur_bucket_size_bytes += ag_output_size_bytes + cur_bucket.append(ag_node) + find_recursive_users_of_fx_node(ag_node, cur_bucket_recursive_users) + if len(cur_bucket) > 1: + # add remaining nodes in the last bucket ag_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - cur_bucket_size_bytes += ag_output_size_bytes - cur_bucket.append(ag_node) - if cur_bucket: - # add remaining nodes in the last bucket - ag_buckets.append(cur_bucket) - return ag_buckets @@ -143,13 +145,9 @@ def bucket_reduce_scatter_by_mb( ) -> list[list[torch.fx.Node]]: """ Identifies all reduce_scatter nodes and groups them into buckets based on size limit `reduce_scatter_bucket_cap_mb_callback`. - - Returns a list of buckets, where each bucket is a list of reduce_scatter nodes. """ - node_list = list(gm.graph.nodes) - # Prerequisite: Check if there is any reduce_scatter node found_reduce_scatter = False for node in node_list: @@ -158,64 +156,71 @@ def bucket_reduce_scatter_by_mb( break if not found_reduce_scatter: return [] - - rs_nodes: list[torch.fx.Node] = [] - + group_name_rs_nodes: dict[tuple[str, str, torch.dtype], list[torch.fx.Node]] = ( # type: ignore[name-defined] + defaultdict(list) + ) # Step 1: Find all reduce_scatter nodes for node in node_list: if is_wait_tensor(node) and is_reduce_scatter_tensor(node.args[0]): rs_node = node.args[0] - rs_nodes.append(rs_node) - + _, reduce_op, group_size, group_name = rs_node.args + dtype = rs_node.meta["val"].dtype + assert isinstance(group_name, str) + assert isinstance(reduce_op, str) + group_name_rs_nodes[(group_name, reduce_op, dtype)].append(rs_node) # Step 2: Put reduce_scatter nodes into buckets rs_buckets: list[list[torch.fx.Node]] = [] - cur_bucket: list[torch.fx.Node] = [] - cur_bucket_size_bytes: int = 0 - cur_bucket_id: int = 0 - # Convert MiB to bytes - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - for rs_node in rs_nodes: - assert is_reduce_scatter_tensor(rs_node) - rs_input = rs_node.args[0] - assert "val" in rs_input.meta # type: ignore[union-attr] - rs_input_size_bytes = ( - rs_input.meta["val"].numel() # type: ignore[union-attr] - * torch.finfo(rs_input.meta["val"].dtype).bits # type: ignore[union-attr] - // 8 + for (group_name, reduce_op, dtype), rs_nodes in group_name_rs_nodes.items(): + cur_bucket: list[torch.fx.Node] = [] + cur_bucket_recursive_users: OrderedSet[torch.fx.Node] = OrderedSet() + cur_bucket_size_bytes: int = 0 + cur_bucket_id: int = 0 + # Convert MiB to bytes + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 ) - if ( - cur_bucket_size_bytes + rs_input_size_bytes - > reduce_scatter_bucket_size_bytes - and cur_bucket - ): - # Current bucket is full, create new bucket - total_size = cur_bucket_size_bytes + rs_input_size_bytes + for rs_node in rs_nodes: + assert is_reduce_scatter_tensor(rs_node) + if rs_node in cur_bucket_recursive_users: + # We can not bucket successors with the node + continue + rs_input = rs_node.args[0] + assert "val" in rs_input.meta # type: ignore[union-attr] + rs_in_val = rs_input.meta["val"] # type: ignore[union-attr] + rs_input_size_bytes = rs_in_val.numel() * rs_in_val.element_size() + if ( + cur_bucket_size_bytes + rs_input_size_bytes + > reduce_scatter_bucket_size_bytes + and cur_bucket + ): + # Current bucket is full, create new bucket + total_size = cur_bucket_size_bytes + rs_input_size_bytes + logger.info( + f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 + f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " + f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"bucket_cap = {reduce_scatter_bucket_size_bytes}" + ) + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) + cur_bucket = [] + cur_bucket_size_bytes = 0 + cur_bucket_id += 1 + reduce_scatter_bucket_size_bytes = int( + reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 + ) + cur_bucket_size_bytes += rs_input_size_bytes + cur_bucket.append(rs_node) + find_recursive_users_of_fx_node(rs_node, cur_bucket_recursive_users) + if cur_bucket: + # add remaining nodes in the last bucket logger.info( - f"Reduce scatter bucket {cur_bucket_id} full: " # noqa: G004 - f"total_size = {total_size} = cur_bucket_size_bytes + rs_input_size_bytes = " - f"{cur_bucket_size_bytes} + {rs_input_size_bytes}," + f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 + f"total_size = {cur_bucket_size_bytes}, " f"bucket_cap = {reduce_scatter_bucket_size_bytes}" ) - rs_buckets.append(cur_bucket) - cur_bucket = [] - cur_bucket_size_bytes = 0 - cur_bucket_id += 1 - reduce_scatter_bucket_size_bytes = int( - reduce_scatter_bucket_cap_mb_callback(cur_bucket_id) * 1024 * 1024 - ) - cur_bucket_size_bytes += rs_input_size_bytes - cur_bucket.append(rs_node) - if cur_bucket: - # add remaining nodes in the last bucket - logger.info( - f"Reduce scatter last bucket {cur_bucket_id}: " # noqa: G004 - f"total_size = {cur_bucket_size_bytes}, " - f"bucket_cap = {reduce_scatter_bucket_size_bytes}" - ) - rs_buckets.append(cur_bucket) - + if len(cur_bucket) > 1: + rs_buckets.append(cur_bucket) return rs_buckets @@ -260,6 +265,18 @@ def env_lookup( # type: ignore[no-untyped-def] return env[x] +def _rank_idx_dict(group_name: str) -> dict[int, int]: + from torch.distributed.distributed_c10d import ( + _resolve_process_group, + get_process_group_ranks, + ) + + pg = _resolve_process_group(group_name) + ranks = get_process_group_ranks(pg) + rank_idx_dict: dict[int, int] = {rank: idx for idx, rank in enumerate(ranks)} + return rank_idx_dict + + def merge_all_gather( gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]] ) -> None: @@ -297,15 +314,13 @@ def merge_all_gather( bucket_id_is_scheduled = {} cast_bucket_id_is_scheduled = {} _, group_size, group_name = next(iter(ag_node_to_wait_node.keys())).args + + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} + for bucket_id, ag_bucket in enumerate(ag_buckets): ag_input_nodes = [] wait_nodes = [] for ag_node in ag_bucket: - assert ( - ag_node in ag_node_to_wait_node - and ag_node.args[1] == group_size - and ag_node.args[2] == group_name - ) ag_input_nodes.append(ag_node.args[0]) wait_nodes.append(ag_node_to_wait_node[ag_node]) bucket_id_to_bucketed_op_info[bucket_id] = ( @@ -314,6 +329,8 @@ def merge_all_gather( group_name, wait_nodes, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] ag_wait_nodes = list(ag_node_to_wait_node.values()) ag_and_wait_nodes = OrderedSet(ag_nodes + ag_wait_nodes) @@ -334,9 +351,6 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) - # device = ag_input_nodes[0].meta["val"].device - # rank = device.index - # dtype = ag_input_nodes[0].meta["val"].dtype if all( n.op == "call_function" # type: ignore[union-attr] and n.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr] @@ -398,6 +412,7 @@ def merge_all_gather( ag_input_nodes, group_size, group_name, orig_wait_nodes = ( bucket_id_to_bucketed_op_info[bucket_id] ) + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] device = ag_input_nodes[0].meta["val"].device # type: ignore[union-attr] rank = device.index dtype = ag_input_nodes[0].meta["val"].dtype # type: ignore[union-attr] @@ -468,7 +483,7 @@ def merge_all_gather( all_gather_output, inp_split_sizes, all_gather_input_numel, - rank, + rank_idx_dict[rank], ), {}, ) @@ -585,6 +600,7 @@ def merge_reduce_scatter( # Prepare bucketed operation info bucket_id_to_bucketed_op_info = {} bucket_id_is_scheduled = {} + group_name_to_rank_idx_dict: dict[str, dict[int, int]] = {} for bucket_id, rs_bucket in enumerate(rs_buckets): _, reduce_op, group_size, group_name = next( iter(rs_node_to_wait_node.keys()) @@ -612,6 +628,8 @@ def merge_reduce_scatter( wait_nodes, wait_node_recursive_users, ) + if group_name not in group_name_to_rank_idx_dict: + group_name_to_rank_idx_dict[group_name] = _rank_idx_dict(group_name) # type: ignore[arg-type, index] new_graph: torch.fx.Graph = torch.fx.Graph() env: dict[torch.fx.Node, torch.fx.Node] = {} @@ -624,155 +642,154 @@ def merge_reduce_scatter( elif node in rs_node_to_wait_node: assert node in rs_node_to_bucket_id bucket_id = rs_node_to_bucket_id[node] - if ( + if not ( bucket_id not in bucket_id_is_scheduled and rs_buckets[bucket_id][-1] == node ): - # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node - ( - rs_input_nodes, - reduce_op, - group_size, - group_name, - orig_wait_nodes, - orig_wait_node_recursive_users, - ) = bucket_id_to_bucketed_op_info[bucket_id] - # parents of rs have been scheduled, so we can directly use the env - unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] - reduce_dtype = unsharded_grads[0].meta["val"].dtype - # Only float32 and bfloat16 are supported for now. - # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. - assert reduce_dtype in ( - torch.float32, - torch.bfloat16, - ), f"reduce_dtype {reduce_dtype} is not supported" - assert all( - grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads - ) - device = unsharded_grads[0].meta["val"].device - rank = device.index - shard_dim = 0 + continue - def _get_dim0_padded_size( - tensor_size: torch.Size, dim0_factor: int - ) -> torch.Size: - padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor - return torch.Size([padded_dim0]) + tensor_size[1:] + # If we are at the last node in the bucket, we can start to schedule the bucketed reduce_scatter node + ( + rs_input_nodes, + reduce_op, + group_size, + group_name, + orig_wait_nodes, + orig_wait_node_recursive_users, + ) = bucket_id_to_bucketed_op_info[bucket_id] + rank_idx_dict = group_name_to_rank_idx_dict[group_name] # type: ignore[index] + # parents of rs have been scheduled, so we can directly use the env + unsharded_grads = [env[x] for x in rs_input_nodes] # type: ignore[index] + reduce_dtype = unsharded_grads[0].meta["val"].dtype + # Only float32 and bfloat16 are supported for now. + # To support fp16, please see FSDP2 `_get_gradient_divide_factors`. + assert reduce_dtype in ( + torch.float32, # type: ignore[attr-defined] + torch.bfloat16, # type: ignore[attr-defined] + ), f"reduce_dtype {reduce_dtype} is not supported" + assert all( + grad.meta["val"].dtype == reduce_dtype for grad in unsharded_grads + ) + device = unsharded_grads[0].meta["val"].device + rank = device.index + rank_idx = rank_idx_dict[rank] + shard_dim = 0 + + def _get_dim0_padded_size( + tensor_size: torch.Size, + dim0_factor: int, # type: ignore[name-defined] + ) -> torch.Size: # type: ignore[name-defined] + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor # type: ignore[attr-defined] + return torch.Size([padded_dim0]) + tensor_size[1:] + + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] + for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + + """ + NOTE: the relationship between the next few nodes is tricky: + - reduce_scatter_input_reshaped is a view of reduce_scatter_input + (same storage, same # elems, different shape). + - chunk_cat writes into reduce_scatter_input_reshaped, + which indirectly writes into reduce_scatter_input + (since they share the same storage). + - reduce_scatter_tensor reads from reduce_scatter_input. + """ + reduce_scatter_input = new_graph_call_function( + new_graph, + torch.ops.aten.empty.memory_format, + ([reduce_scatter_input_numel],), + { + "dtype": reduce_dtype, + "device": device, + "pin_memory": False, + }, + ) + reduce_scatter_input_reshaped = new_graph_call_function( + new_graph, + torch.ops.aten.reshape.default, + (reduce_scatter_input, [group_size, -1]), + {}, + ) + new_graph_call_function( + new_graph, + torch.ops.fsdp.chunk_cat.default, + (unsharded_grads,), + { + "dim": 0, + "num_chunks": group_size, + "out": reduce_scatter_input_reshaped, + }, + ) + reduce_scatter_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + (reduce_scatter_input, reduce_op, group_size, group_name), + {}, + ) - padded_unsharded_sizes = tuple( - _get_dim0_padded_size(grad.meta["val"].size(), group_size) # type: ignore[arg-type] - for grad in unsharded_grads - ) - reduce_scatter_input_numel = sum( - s.numel() for s in padded_unsharded_sizes - ) + wait_tensor = new_graph_call_function( + new_graph, + torch.ops._c10d_functional.wait_tensor.default, + (reduce_scatter_tensor,), + {}, + ) - """ - NOTE: the relationship between the next few nodes is tricky: - - reduce_scatter_input_reshaped is a view of reduce_scatter_input - (same storage, same # elems, different shape). - - chunk_cat writes into reduce_scatter_input_reshaped, - which indirectly writes into reduce_scatter_input - (since they share the same storage). - - reduce_scatter_tensor reads from reduce_scatter_input. - """ - reduce_scatter_input = new_graph_call_function( - new_graph, - torch.ops.aten.empty.memory_format, - ([reduce_scatter_input_numel],), - { - "dtype": reduce_dtype, - "device": device, - "pin_memory": False, - }, + def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int + ) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + reduce_output = wait_tensor + # View out and accumulate sharded gradients + new_sharded_grads = [] + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, unsharded_grad in zip( + padded_unsharded_sizes, unsharded_grads + ): + # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here + chunks = _chunk_with_empty( + torch.empty_like(unsharded_grad.meta["val"], device="meta"), + group_size, # type: ignore[arg-type] + dim=shard_dim, ) - reduce_scatter_input_reshaped = new_graph_call_function( - new_graph, - torch.ops.aten.reshape.default, - (reduce_scatter_input, [group_size, -1]), - {}, + sharded_param = chunks[rank_idx] + sharded_size = sharded_param.size() + contiguous_sharded_stride = ( + torch._prims_common.make_contiguous_strides_for(sharded_size) ) - new_graph_call_function( + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = new_graph_call_function( new_graph, - torch.ops.fsdp.chunk_cat.default, - (unsharded_grads,), + torch.ops.aten.as_strided.default, + (reduce_output,), { - "dim": 0, - "num_chunks": group_size, - "out": reduce_scatter_input_reshaped, + "size": sharded_size, + "stride": contiguous_sharded_stride, + "storage_offset": flat_grad_offset, }, ) - reduce_scatter_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - (reduce_scatter_input, reduce_op, group_size, group_name), - {}, - ) - - wait_tensor = new_graph_call_function( - new_graph, - torch.ops._c10d_functional.wait_tensor.default, - (reduce_scatter_tensor,), - {}, - ) - - def _chunk_with_empty( - tensor: torch.Tensor, num_chunks: int, dim: int - ) -> list[torch.Tensor]: - chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) - while len(chunks) < num_chunks: - chunks.append(chunks[0].new_empty(0)) - return chunks - - reduce_output = wait_tensor - # View out and accumulate sharded gradients - new_sharded_grads = [] - flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] - for padded_unsharded_size, unsharded_grad in zip( - padded_unsharded_sizes, unsharded_grads - ): - # NOTE: we only care about the shape of tensors in `chunks`, so using meta tensor here - chunks = _chunk_with_empty( - torch.empty_like(unsharded_grad.meta["val"], device="meta"), - group_size, # type: ignore[arg-type] - dim=shard_dim, - ) - sharded_param = chunks[rank] - sharded_size = sharded_param.size() - contiguous_sharded_stride = ( - torch._prims_common.make_contiguous_strides_for(sharded_size) - ) - # Assume even sharding for Shard(i), i > 0; otherwise would require - # copy-out for contiguous strides - new_sharded_grad = new_graph_call_function( - new_graph, - torch.ops.aten.as_strided.default, - (reduce_output,), - { - "size": sharded_size, - "stride": contiguous_sharded_stride, - "storage_offset": flat_grad_offset, - }, - ) - new_sharded_grads.append(new_sharded_grad) - padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] - flat_grad_offset += padded_sharded_numel # type: ignore[assignment] - assert len(orig_wait_nodes) == len(new_sharded_grads) - assert len(orig_wait_nodes) > 0 - for new_sharded_grad, orig_wait_node in zip( - new_sharded_grads, orig_wait_nodes - ): - env[orig_wait_node] = new_sharded_grad # noqa: PERF403 - for user in sorted( - orig_wait_node_recursive_users, key=lambda x: order[x] - ): - # We skip output node here, because output node will be inserted (later) - # as the last node in the new graph. - if user.op != "output": - node_copy( - env, new_graph, user, lambda x: env_lookup(env, x, user) - ) - bucket_id_is_scheduled[bucket_id] = True + new_sharded_grads.append(new_sharded_grad) + padded_sharded_numel = padded_unsharded_size.numel() // group_size # type: ignore[operator] + flat_grad_offset += padded_sharded_numel # type: ignore[assignment] + assert len(orig_wait_nodes) == len(new_sharded_grads) + assert len(orig_wait_nodes) > 0 + for new_sharded_grad, orig_wait_node in zip( + new_sharded_grads, orig_wait_nodes + ): + env[orig_wait_node] = new_sharded_grad # noqa: PERF403 + for user in sorted(orig_wait_node_recursive_users, key=lambda x: order[x]): + # We skip output node here, because output node will be inserted (later) + # as the last node in the new graph. + if user.op != "output": + node_copy(env, new_graph, user, lambda x: env_lookup(env, x, user)) + bucket_id_is_scheduled[bucket_id] = True else: continue assert node_list[-1].op == "output" diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py index e6757c3ad9e3..30cfcdd615fb 100644 --- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py +++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -15,11 +15,17 @@ log = logging.getLogger(__name__) # TODO: need a better strategy for decomposing mm +# The following two constants are for CUDA device only MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 MAX_OTHER_DIMENSION_DECOMPOSITION = 32 +# The following two constants are for CPU device only +CPU_MAX_FIRST_DIMENSION_DECOMPOSITION = 1 +CPU_MAX_OTHER_DIMENSION_DECOMPOSITION = 2048 min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +cpu_max_first_dimension_decomposition = CPU_MAX_FIRST_DIMENSION_DECOMPOSITION +cpu_max_other_dimension_decomposition = CPU_MAX_OTHER_DIMENSION_DECOMPOSITION if "decompose_mm_pass" in config.post_grad_fusion_options: min_first_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" @@ -27,6 +33,16 @@ max_other_dimension_decomposition = config.post_grad_fusion_options[ "decompose_mm_pass" ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + cpu_max_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_first_dimension_decomposition", CPU_MAX_FIRST_DIMENSION_DECOMPOSITION + ) + cpu_max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get( + "cpu_max_other_dimension_decomposition", CPU_MAX_OTHER_DIMENSION_DECOMPOSITION + ) def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: @@ -57,7 +73,10 @@ def should_decompose_bmm(mat1, mat2) -> bool: return False return True elif check_device(mat1, mat2, device="cpu"): - if mat1.shape[0] == 1 and mat2.shape[0] == 1: + if ( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + and mat2.shape[0] <= cpu_max_first_dimension_decomposition + ): return True return False @@ -77,9 +96,15 @@ def should_decompose_mm(mat1, mat2) -> bool: and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) ) or ( check_device(mat1, mat2, device="cpu") - and statically_known_true(mat1.shape[0] == 1) - and statically_known_true(mat2.shape[0] <= 128) - and statically_known_true(mat2.shape[1] <= 512) + and statically_known_true( + mat1.shape[0] <= cpu_max_first_dimension_decomposition + ) + and statically_known_true( + mat2.shape[0] <= cpu_max_other_dimension_decomposition + ) + and statically_known_true( + mat2.shape[1] <= cpu_max_other_dimension_decomposition + ) ) diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 4ed950afe9a1..5f449eb49664 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -18,7 +18,6 @@ log = logging.getLogger(__name__) aten = torch.ops.aten - _scaled_dot_product_attention = aten.scaled_dot_product_attention @@ -582,6 +581,42 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): ) +def _sfdp_pattern_24(query, key, value, attention_mask): + """ + this pattern is for MBartForCausalLM/PLBartForCausalLM. + attn_mask has a different dtype with QKV. + there is no scale in sdpa. + """ + bs = query.size(0) + n_head = query.size(1) + seq_len = query.size(2) + head_size = query.size(3) + q = query.view(bs * n_head, -1, head_size) + k = key.reshape(bs * n_head, -1, head_size) + v = value.reshape(bs * n_head, -1, head_size) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = attn_weights.view(bs, n_head, seq_len, -1) + attention_mask + attn_weights = attn_weights.view(bs * n_head, seq_len, -1) + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + if query.dtype == torch.half: + attn_weights = attn_weights.to(torch.half) + attn_output = torch.bmm(attn_weights, v) + attn_output = attn_output.view(bs, n_head, seq_len, head_size) + return attn_output + + +def _sfdp_replacement_24(query, key, value, attention_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask.to(dtype=query.dtype), + is_causal=False, + scale=1, + ) + + def _sfdp_pattern_21(query, key, value, attn_mask): # for T5 with inplace add query = query.permute([0, 2, 1, 3]) @@ -1003,6 +1038,13 @@ def _get_sfdp_patterns(): {}, _sfdp_params_check, ), + ( + _sfdp_pattern_24, + _sfdp_replacement_24, + [g(), g(), g(), b_float()], + {}, + _sfdp_extra_check, + ), ] mask_fp32_patterns = ["pattern_16"] if dtype == torch.half: diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index af40d987f7d1..c4d935a4f8bb 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -850,9 +850,11 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: Returns boolean indicating if fusion was successful or not. """ - assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), ( - "torch.distributed and NCCL must be available to use async tensor parallelism" - ) + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return from torch.distributed._symmetric_memory import ( is_symm_mem_enabled_for_group, @@ -875,9 +877,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: reduce_scatter.group_name, ) - assert is_symm_mem_enabled_for_group(group_name), ( - f"symmetric memory is not enabled for process group {group_name}, this is required for async TP" - ) + if not is_symm_mem_enabled_for_group(group_name): + return # Currently fused_matmul_reduce_scatter doesn't return the matmul result, # so we can't apply the fusion if the matmul result is used by multiple diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index a269b17e3a2a..e5a0c0dc51c5 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -1228,11 +1228,14 @@ def is_const_or_cat_by_const(weight): torch.bfloat16, torch.float16, ) - bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] - use_bf16_for_fp32_weight = ( - bf32_matmul_enabled and weight_meta_value.dtype == torch.float32 + reduced_f32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision in [ # type: ignore[attr-defined] + "bf16", + "tf32", + ] + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_meta_value.dtype == torch.float32 ) - compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. # on aarch64, use mkldnn op for fp32 as well if acl is enabled if ( @@ -1449,13 +1452,13 @@ def linear(match, *args, **kwargs): torch.bfloat16, torch.float16, ) - bf32_matmul_enabled = ( - torch.backends.mkldnn.matmul.fp32_precision == "bf16" # type: ignore[attr-defined] + reduced_f32_matmul_enabled = ( + torch.backends.mkldnn.matmul.fp32_precision in ["bf16", "tf32"] # type: ignore[attr-defined] ) - use_bf16_for_fp32_weight = ( - bf32_matmul_enabled and weight_dtype == torch.float32 + use_reduced_f32_for_fp32_weight = ( + reduced_f32_matmul_enabled and weight_dtype == torch.float32 ) - compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight + compute_with_lp = is_lp_weight or use_reduced_f32_for_fp32_weight batch_size = input.meta.get("val").shape[0] if has_free_symbols(batch_size): assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), ( diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 862df99a41e5..70dfe9ae43b3 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -72,7 +72,13 @@ def _get_pattern_output_dtype(match: Match): output_node = pattern_output_nodes[0] assert isinstance(output_node, torch.fx.Node) output_dtype = output_node.meta["val"].dtype - assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + assert output_dtype in [ + torch.int8, + torch.uint8, + torch.float32, + torch.bfloat16, + torch.float8_e4m3fn, + ] return output_dtype diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py new file mode 100644 index 000000000000..72f23373c143 --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_24.py @@ -0,0 +1,153 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=4) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +mul_Tensor = CallFunction(aten.mul.Tensor, bmm_default_2, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, view_default_7, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, div_Tensor, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, div_Tensor, view_default_4) +_sfdp_pattern_24_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored(), _users=2) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_1, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_7 = CallFunction(aten.view.default, fma_default, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +view_default_10 = CallFunction(aten.view.default, permute_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, convert_element_type_default, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_24_half_training = MultiOutputPattern([view_default_5, + view_default_9, + view_default_10, + view_default_11, + None +]) + + +view_default = CallFunction(aten.view.default, KeywordArg('query'), Ignored()) +view_default_1 = CallFunction(aten.view.default, KeywordArg('key'), Ignored()) +permute_default = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, permute_default) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attention_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_3, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_3, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, KeywordArg('value'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, convert_element_type_default, view_default_4) +_sfdp_pattern_24_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index ba0529a5fad9..327f96ae34ac 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -62,6 +62,7 @@ "split_stack_to_cats_pass", "unbind_stack_to_slices_pass", "move_reshape_out_of_split_stack_pass", + "einsum_to_pointwise_pass", ] post_grad_pass_names = [ @@ -1790,7 +1791,11 @@ def merge_split_cat_aten(match: Match, *args, **kwargs): for cat_node in list(getitem_nodes[0].users.keys()): cat_dim = get_arg_value(cat_node, 1, "dim") cat_inputs = get_arg_value(cat_node, 0, "tensors") - if len(cat_inputs) < threshold_to_cat: + try: + cat_input_len = len(cat_inputs) + except TypeError: + continue + if cat_input_len < threshold_to_cat: continue # check split node and cat node has same dim, and all getitem nodes have same parent node parent_to_indices = defaultdict(list) # type: ignore[var-annotated] @@ -2965,3 +2970,65 @@ def move_view_after_cat(match: Match, *args, **kwargs): view_node.meta.update(cat_node.meta) graph.erase_node(cat_node) counters["inductor"]["move_view_after_cat_aten_pass"] += 1 + + +def match_einsum_strings(s: str) -> bool: + """ + This function takes a string s as input, where s is in the format "3 letter string, + 4 letter string -> 3 letter string". + It checks if the strings match the rule and returns True if they do, False otherwise. + + The rule is: + - The three strings have the same first two characters. + - The first two strings have the same third character. + - The second and third strings have the same last character. + """ + + # Split the input string into parts + parts = s.replace("->", ",").split(",") + + # Strip leading/trailing whitespaces from each part + parts = [part.strip() for part in parts] + + # Check if we have exactly three parts + if len(parts) != 3: + return False + + # Extract the strings + s1, s2, s3 = parts + + # Check if the strings have the correct lengths + if len(s1) != 3 or len(s2) != 4 or len(s3) != 3: + return False + + # Check the rule + return s1[:2] == s2[:2] == s3[:2] and s1[2] == s2[2] and s2[3] == s3[2] + + +@register_graph_pattern( + CallFunctionVarArgs(torch.functional.einsum, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("einsum_to_pointwise_pass"), +) +def replace_einsum_to_pointwise(match: Match, *args, **kwargs): + def repl(input, weights): + return (input.unsqueeze(-1) * weights).sum(-2) + + def should_replace_einsum(einsum_node) -> bool: + equation = get_arg_value(einsum_node, 0) + users = einsum_node.users.keys() + # for now, we only consider the case of two operands + return ( + len(einsum_node.args) == 3 + and is_node_meta_valid(input) + and is_node_meta_valid(weights) + and any( + user.target == "add" or user.target == operator.add for user in users + ) + and match_einsum_strings(equation) + ) + + einsum_node = match.nodes[0] + input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2) + if should_replace_einsum(einsum_node): + match.replace_by_example(repl, [input, weights]) + counters["inductor"]["einsum_to_pointwise_pass"] += 1 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e2cc101533f2..ac299d5b0c2d 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -123,6 +123,7 @@ from torch.fx.graph import Graph from .codegen.wrapper import PythonWrapperCodegen + from .dependencies import Dep from .scheduler import BaseSchedulerNode CompiledModule = Union[ModuleType, FileBackedGraphModule] @@ -485,6 +486,9 @@ def __init__( self.bw_donated_idxs = get_donated_idxs() + # Cache for dep size hints to avoid expensive recomputation + self.dep_size_hint_cache: dict[Dep, int] = {} + def freeze_runtime_asserts(self) -> None: self._shape_env.freeze_runtime_asserts() @@ -570,6 +574,23 @@ def has_feature( assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) + def get_dep_size_hint(self, dep: Dep) -> int: + """ + Get the size hint for a dependency with caching to avoid expensive recomputation. + """ + if dep not in self.dep_size_hint_cache: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + self.dep_size_hint_cache[dep] = res + return self.dep_size_hint_cache[dep] + def get_current_device_or_throw(self) -> torch.device: if device := self.current_device: return device diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1edbb214ae2a..a21b9c50938e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -541,12 +541,23 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: class IRNode: + """Base class for all intermediate representation (IR) nodes in TorchInductor. + + Note: + This is an abstract base class. Most methods raise NotImplementedError + and must be overridden by concrete subclasses. + """ + _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() # NB: These are kinda weird, origins: OrderedSet[Any] = dataclasses.field(init=False) + # traces back to where the IRNode is created in Inductor traceback: Optional[list[str]] = dataclasses.field(init=False) origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) + # trace backs to user model code + # a single IRNode could correspond to multiple lines of code + stack_traces: dict[str, str] = dataclasses.field(init=False) @staticmethod @contextlib.contextmanager @@ -578,12 +589,41 @@ def _post_init_setattr(self, attr: str, value: Any) -> None: object.__setattr__(self, attr, value) def __post_init__(self) -> None: - self._post_init_setattr("origins", OrderedSet(self._current_origins)) + origins = OrderedSet(self._current_origins) + self._post_init_setattr("origins", origins) self._post_init_setattr( "traceback", traceback.format_stack() if config.debug_ir_traceback else None ) self._post_init_setattr("origin_node", None) + # Group nodes by their stack traces to deduplicate + nodes_to_stack_trace = {} + if config.trace.provenance_tracking: + for node in origins: + if node.stack_trace: + # nodes in the backward graph don't have mapping to pre_grad_graph + nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace + else: + if ( + "postToPre" + not in torch._inductor.debug._inductor_post_to_pre_grad_nodes + ): + continue + node_names = torch._inductor.debug._inductor_post_to_pre_grad_nodes[ + "postToPre" + ].get(node.name, None) + if node_names: + for node_name in node_names: + stack_trace = torch._inductor.debug._inductor_pre_grad_node_stack_trace.get( + node_name, None + ) + if stack_trace: + nodes_to_stack_trace["pre_grad+" + node_name] = ( + stack_trace + ) + + self._post_init_setattr("stack_traces", nodes_to_stack_trace) + def get_read_names(self) -> OrderedSet[str]: return OrderedSet(dep.name for dep in self.get_reads()) @@ -601,7 +641,15 @@ def common_repr(self, shorten: bool = True) -> Sequence[str]: if shorten and len(origins) > 64: # this can get *very* long origins = f"{origins[:61]}..." - return [origins] + if not self.stack_traces: + return [origins] + + stack_trace_str = [] + for stack_trace in self.stack_traces.values(): + stack_trace_str.append("stack_traces = {{") + stack_trace_str += stack_trace.split("\n") + stack_trace_str.append("}") + return [origins] + stack_trace_str def str_helper( self, lines: Sequence[object], shorten: bool = True, multiline: bool = True @@ -2828,7 +2876,7 @@ def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLik assert old_size[i] is not None new_size[i] = old_size[i] elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(old_size[i], 1), size_oblivious=True + sympy.Eq(old_size[i], 1), fallback_value=False ): pass else: @@ -2855,7 +2903,7 @@ def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView: new_stride.append( stride if not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(size, 1), size_oblivious=True + sympy.Eq(size, 1), fallback_value=False ) else sympy.S.Zero ) @@ -7829,6 +7877,10 @@ def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]: class StorageBox(MutableBox): + """ + StorageBox allow in-place mutation of Tensors + """ + def is_input_buffer(self) -> bool: if isinstance(self.data, (InputBuffer, ReinterpretView)): return self.data.get_name() in V.graph.graph_inputs @@ -7878,10 +7930,21 @@ def realize_hint(self) -> None: ): self.realize() + def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool: + return ( + sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold + ) + def has_exceeded_max_reads(self) -> bool: return isinstance(self.data, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() + or ( + config.realize_acc_reads_size_threshold is not None + and self.has_accumulated_enough_reads_by_size( + config.realize_acc_reads_size_threshold + ) + ) ) def should_realize_on_reuse(self, users: int) -> bool: diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index d311b62950bd..19ca389c2a53 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -22,6 +22,7 @@ get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_triton_template, ) from .mm_common import ( _is_static_problem, @@ -434,23 +435,30 @@ def grouped_mm_args( if out_dtype is None: out_dtype = mat1.get_dtype() + alignment = 16 // out_dtype.itemsize - dims = [] if m1dim == 2: if m2dim == 2: assert offs is not None - dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + out_size = [offs.get_size()[0], mat1_size[0], mat2_size[1]] else: - dims = [mat1_size[0], mat2_size[-1]] + out_size = [mat1_size[0], mat2_size[-1]] else: if m2dim == 2: - dims = [mat1_size[1], mat2_size[1]] + out_size = [mat1_size[1], mat2_size[1]] else: - dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + out_size = [mat1_size[0], mat1_size[1], mat2_size[-1]] + size_padded = (out_size[-1] + alignment - 1) // alignment * alignment + if len(out_size) == 2: + out_stride = [size_padded, 1] + else: + out_stride = [out_size[1] * size_padded, size_padded, 1] + layout = FixedLayout( mat1.get_device(), out_dtype, - dims, + out_size, + out_stride, ) else: assert out_dtype is None, "out_dtype is ignored if layout is specified." @@ -604,7 +612,11 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. - if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + if ( + is_nonzero + and use_triton_template(layout) + and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) + ): scaled = scale_a is not None if len(m1_size) == 2: if len(m2_size) == 2: diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index c4c8f70003c6..503795bc513c 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -489,11 +489,11 @@ def broadcast_symbolic_shapes(a, b): output = [] for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One): if V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(y, 1), size_oblivious=True + sympy.Eq(y, 1), fallback_value=False ): output.append(x) elif V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(x, 1), size_oblivious=True + sympy.Eq(x, 1), fallback_value=False ): output.append(y) else: @@ -939,26 +939,14 @@ def broadcast_tensors(*inputs): outputs = [] for x in inputs: sizes = x.get_size() - if len(sizes) != len(target) or any( - ( - ( - V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(a, 1), size_oblivious=True - ) - and not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(b, 1), size_oblivious=True - ) - ) - or ( - not V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(a, 1), size_oblivious=True - ) - and V.graph.sizevars.shape_env.evaluate_expr( - sympy.Eq(b, 1), size_oblivious=True - ) - ) + + def is_length_one(size: sympy.Expr): + return V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), fallback_value=False ) - for a, b in zip(sizes, target) + + if len(sizes) != len(target) or any( + is_length_one(a) != is_length_one(b) for a, b in zip(sizes, target) ): x = expand(x, target) outputs.append(x) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 5601bc4adcee..d287208419a9 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -78,19 +78,8 @@ def get_freeable_input_buf( A dictionary containing all freeble input buffers, keyed by their names. """ - # this function is copied from torch/_inductor/scheduler.py - # TODO: would be nice to remove the try/except block for both places def _dep_size_hint(dep: Dep) -> int: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - return res + return V.graph.get_dep_size_hint(dep) # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index e7981bc8746b..3b3a7b072534 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -675,8 +675,8 @@ def qlinear_unary( algorithm, layout=None, ): - assert packed_weight.get_dtype() is torch.int8, ( - "Only int8 weights are supported by oneDNN qlinear." + assert packed_weight.get_dtype() in [torch.int8, torch.float8_e4m3fn], ( + "Only int8 and e4m3fn weights are supported by oneDNN qlinear." ) x_size = x.get_size() if len(x_size) > 2: diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index 726b41d97240..bd11d033cadb 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -105,7 +105,7 @@ def load_package( run_single_threaded: bool = False, num_runners: int = 1, device_index: int = -1, -) -> AOTICompiledModel: # type: ignore[type-arg] +) -> AOTICompiledModel: try: pt2_contents = load_pt2( path, diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 1da31586b0a1..772ddcced96f 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -127,7 +127,7 @@ def _transfer_meta( # transfer metadata after pattern matching occurs. # skip "val" and "tensor_meta" because this info is too specific; it's unlikely # to remain accurate after pattern matching has occurred. - if config.trace.enabled: + if config.trace.provenance_tracking: # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. new_from_node = new_meta.get("from_node", []).copy() new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) @@ -143,6 +143,8 @@ def _transfer_meta( for k, v in old_node.meta.items() if k in torch.fx.proxy._COPY_META_FIELDS ) + if "stack_trace" in old_node.meta: + new_meta["stack_trace"] = old_node.meta["stack_trace"] class Match: @@ -318,7 +320,12 @@ def record(node: torch.fx.Node, val: Any) -> None: ] else: - example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + example_vals = torch.fx.map_arg( + args, + lambda arg: arg.meta["val"] + if "val" in arg.meta + else arg.meta["example_value"], + ) replacement = trace_fn(replacement_fn, example_vals) if len(self.nodes) == 1: for n in replacement.graph.nodes: diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index aaa266b60e00..1304ce79b86e 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -170,10 +170,13 @@ def get(self, key: str) -> Optional[_T]: try: result = self._get(key, sample) cache_stats.get(type(self).__name__, result) - except Exception: + except Exception as e: cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) raise - self._log_sample(sample) + finally: + self._log_sample(sample) return result # Add `value` to the cache with the key `key`. Note that `None` is not a @@ -186,10 +189,13 @@ def put(self, key: str, value: _T) -> None: try: self._put(key, value, sample) cache_stats.put(type(self).__name__) - except Exception: + except Exception as e: cache_stats.exception(type(self).__name__) + if sample: + sample.fail_reason = str(e) raise - self._log_sample(sample) + finally: + self._log_sample(sample) # Used to convert data from the cache into structured data. def _decode(self, data: _U, sample: Optional[Sample]) -> _T: # type: ignore[override] diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 01d038aab8e7..88b9c80c7714 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -35,6 +35,7 @@ from typing_extensions import override import torch +from torch._dynamo.precompile_context import PrecompileContext from torch._inductor.runtime.runtime_utils import cache_dir from torch.compiler._cache import ( CacheArtifact, @@ -125,6 +126,7 @@ def create( ) -> Optional[AutotuneCache]: cache = AutotuneCache(configs_hash) key = AutotuneCache._prepare_key(filename) + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) cache._setup_remote_autotune_cache(inductor_meta, key) if cache.local_cache or cache.remote_cache: @@ -300,6 +302,10 @@ def save( CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, data ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -625,6 +631,10 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: CacheArtifactManager.record_artifact( AutotuneCacheArtifact.type(), autotune_artifact_key, result ) + if torch._dynamo.config.caching_precompile: + PrecompileContext.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) return result @override diff --git a/torch/_inductor/runtime/triton_compat.py b/torch/_inductor/runtime/triton_compat.py index 877f72b50c55..645e0f4c8903 100644 --- a/torch/_inductor/runtime/triton_compat.py +++ b/torch/_inductor/runtime/triton_compat.py @@ -69,7 +69,18 @@ def GPUTarget( def _log2(x: Any) -> Any: raise NotImplementedError - HAS_WARP_SPEC = hasattr(tl, "async_task") + def _triton_config_has(param_name: str) -> bool: + if not hasattr(triton, "Config"): + return False + if not hasattr(triton.Config, "__init__"): + return False + return param_name in inspect.signature(triton.Config.__init__).parameters + + HAS_WARP_SPEC = ( + hasattr(tl, "async_task") + and _triton_config_has("num_consumer_groups") + and _triton_config_has("num_buffers_warp_spec") + ) try: from triton import knobs diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 5c7a16d25bc6..f3986b897161 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2051,15 +2051,12 @@ class Scheduler: optimizations such as fusion, reorder, and graph partition. """ - __dep_size_hint_cache: dict[Dep, int] - def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() - self.__dep_size_hint_cache = {} V.graph.scheduler = self self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) @@ -3505,6 +3502,14 @@ def _find_single_user_inputs( return True return False + def fusion_accumulate_large_reads( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int + ) -> bool: + all_reads = (node1.read_writes.reads | node2.read_writes.reads) - ( + node1.read_writes.writes | node2.read_writes.writes + ) + return sum(self.dep_size_hint(dep) for dep in all_reads) > threshold + def are_long_distant_nodes( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: @@ -4010,20 +4015,7 @@ def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool: return False def dep_size_hint(self, dep: Dep) -> int: - res = 0 - if dep not in self.__dep_size_hint_cache: - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - self.__dep_size_hint_cache[dep] = res - else: - res = self.__dep_size_hint_cache[dep] - return res + return V.graph.get_dep_size_hint(dep) def score_fusion_memory( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index c7c49333d1ae..903d616bb91e 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -167,11 +167,25 @@ class PartialRender: of replacements after the initial render. """ + FINALIZED_HOOK: object = object() + def __init__(self, code, replacement_hooks) -> None: super().__init__() - self.code = code + self._code = code self.replacement_hooks = replacement_hooks + @property + def code(self): + remaining_active_hooks = [ + key + for key, fn in self.replacement_hooks.items() + if fn is not self.FINALIZED_HOOK + ] + assert len(remaining_active_hooks) == 0, ( + f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}" + ) + return self._code + def finalize_hook(self, hook_key: str, strict=True) -> None: if hook_key not in self.replacement_hooks: if strict: @@ -180,15 +194,28 @@ def finalize_hook(self, hook_key: str, strict=True) -> None: ) else: return - assert self.replacement_hooks[hook_key] is not None, ( + assert self.replacement_hooks[hook_key] is not self.FINALIZED_HOOK, ( "hook_key can only be called once" ) - self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) - self.replacement_hooks[hook_key] = None + self._code = self._code.replace(hook_key, self.replacement_hooks[hook_key]()) + self.replacement_hooks[hook_key] = self.FINALIZED_HOOK - def finalize_all(self) -> str: + def finalize_remaining(self) -> str: + """ + Finalize the remaining active hooks. This function can be used in cases + where the caller uses `finalize_hook` rather than `finalize_all`. + Note: `finalize_all` errors if a hook that has already been finalized + is attempted to be called again. This function only attempts to + finalize active hooks. + """ for key, fn in self.replacement_hooks.items(): - self.code = self.code.replace(key, fn()) + if fn is not self.FINALIZED_HOOK: + self.finalize_hook(key) + return self.code + + def finalize_all(self) -> str: + for key in self.replacement_hooks: + self.finalize_hook(key) return self.code @@ -2310,20 +2337,7 @@ def autotune(choices, hint_override: Optional[int] = None): f"{name}_template_autotuning", log_pt2_compile_event=True, dynamo_compile_column_us="compile_time_autotune_time_us", - metadata={ - "autotune_strides": ", ".join( - [str(n.get_stride()) for n in input_nodes] - ), - "autotune_dtypes": ", ".join( - [str(n.get_dtype()) for n in input_nodes] - ), - "autotune_shape": ", ".join( - ["x".join(map(str, n.get_size())) for n in input_nodes] - ), - "autotune_offset": ", ".join( - [str(n.get_layout().offset) for n in input_nodes] - ), - }, + metadata=_autotune_metadata(input_nodes), ): return benchmark(choices, hint_override=hint_override) @@ -3343,5 +3357,44 @@ def sympy_call(self, *args, **kwargs): return self.fn(*args, **kwargs, **self.kwargs_sym) +def _autotune_metadata(input_nodes): + """Helper function to extract autotune metadata from input nodes.""" + return { + "autotune_strides": ", ".join([str(n.get_stride()) for n in input_nodes]), + "autotune_dtypes": ", ".join([str(n.get_dtype()) for n in input_nodes]), + "autotune_shape": ", ".join( + ["x".join(map(str, n.get_size())) for n in input_nodes] + ), + "autotune_offset": ", ".join([str(n.get_layout().offset) for n in input_nodes]), + # TODO(coconutruben): replace this with taking KernelInputs as the + # argument, and extracting those out there directly + "autotune_strides_hinted": ", ".join( + [ + str( + V.graph.sizevars.size_hints( + n.get_stride(), + fallback=config.unbacked_symint_fallback, + ) + ) + for n in input_nodes + ] + ), + "autotune_shape_hinted": ", ".join( + [ + "x".join( + map( + str, + V.graph.sizevars.size_hints( + n.get_size(), + fallback=config.unbacked_symint_fallback, + ), + ) + ) + for n in input_nodes + ] + ), + } + + # ensure lowering is imported so that `extern_kernels.*` is populated from . import lowering # noqa: F401 diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 40a964518679..65a6851192a0 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -707,6 +707,18 @@ class CUDAConfigHeuristic(BaseConfigHeuristic): def __init__(self) -> None: super().__init__() + self.b200_default_flex_config = { + (torch.float32, 64): FlexConfig(128, 32, 3, 4), + (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 256): FlexConfig(32, 32, 3, 4), + (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), + (torch.bfloat16, 128): FlexConfig(128, 64, 2, 8), + (torch.bfloat16, 256): FlexConfig(64, 32, 3, 4), + (torch.float16, 64): FlexConfig(128, 128, 3, 4), + (torch.float16, 128): FlexConfig(128, 128, 3, 8), + (torch.float16, 256): FlexConfig(64, 32, 3, 4), + } + self.h100_default_flex_config = { (torch.float32, 64): FlexConfig(128, 32, 3, 4), (torch.float32, 128): FlexConfig(32, 64, 3, 4), @@ -745,7 +757,11 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi default_config = FlexConfig(64, 64, 3, 4) else: default_config = FlexConfig(128, 64, 3, 4) - if capability >= (9, 0): + if capability >= (10, 0): + default_config = self.b200_default_flex_config.get( + (dtype, head_dim), default_config + ) + elif capability >= (9, 0): default_config = self.h100_default_flex_config.get( (dtype, head_dim), default_config ) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 10701d0d8b2d..aef81712d17e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -81,15 +81,7 @@ from .codegen.common import WorkspaceArg from .codegen.wrapper import PythonWrapperCodegen from .graph import GraphLowering - from .ir import ( - Buffer, - ExternKernel, - ExternKernelOut, - IRNode, - Layout, - Operation, - ReinterpretView, - ) + from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView from .output_code import CompiledFxGraph from .scheduler import BaseSchedulerNode, SchedulerBuffer @@ -1099,13 +1091,17 @@ def fresh_cache( """ clear_caches() - inductor_cache_dir = tempfile.mkdtemp(dir=dir) + from torch._inductor.cpp_builder import normalize_path_separator + + inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir)) try: with mock.patch.dict( os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} ): log.debug("Using inductor cache dir %s", inductor_cache_dir) - triton_cache_dir = os.path.join(inductor_cache_dir, "triton") + triton_cache_dir = normalize_path_separator( + os.path.join(inductor_cache_dir, "triton") + ) with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): yield if isinstance(cache_entries, dict): @@ -1539,35 +1535,87 @@ def use_triton_template( ) -def use_triton_tma_template(*matrices: IRNode) -> bool: +def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool: + """ + Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints + that Triton relies on today. + * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html + + A tensor is accepted when: + * 2 ≤ rank ≤ 5 + * dtype ∈ {FP16, BF16, FP8-E4M3FN} + * Every logical size ≥ 2 + * Base pointer 16-byte aligned + * All "outer" dims have 16-byte aligned strides + * The “inner” dim has stride 1 (contiguous) + * For FP8 tensors, inner dim ≥ 32 + """ from torch.utils._triton import has_triton_tma_device from .virtualized import V + def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool: + return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT) + def _is_tma_compatible(x: IRNode) -> bool: - if len(x.get_size()) != 2: + sizes = x.get_size() + strides = x.get_stride() + rank = len(sizes) + dtype = x.get_dtype() + itemsize = dtype.itemsize + + # 2 ≤ rank ≤ 5 + if rank < 2 or rank > 5: return False - dtype = x.get_dtype() + # dtype ∈ {FP16, BF16, FP8-E4M3FN} if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): return False - layout = x.get_layout() - transposed = layout.is_transposed() - if not (layout.is_contiguous() or transposed): + # Base pointer 16-byte aligned + if x.get_name() in V.graph.unaligned_buffers: return False - inner_dim = layout.size[1] - if transposed: - inner_dim = layout.size[0] + if add_guards: + sizes_i = V.graph.sizevars.guard_int_seq(sizes) + strides_i = V.graph.sizevars.guard_int_seq(strides) + else: + sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes] + strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides] + + # Every logical size ≥ 2 + if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i): + return False + + # Find the single contiguous (“inner”) dim + inner = [ + i + for i, st in enumerate(strides_i) + if V.graph.sizevars.statically_known_equals(st, 1) + ] + if len(inner) != 1: + return False + inner_idx = inner[0] + + # All "outer" dims must have 16-byte aligned strides + for i, st in enumerate(strides_i): + if i == inner_idx: + continue + if not _aligned(st * itemsize): + return False + + # Inner dim byte width must still be a multiple of 16 B + inner_dim = sizes_i[inner_idx] + if not _aligned(inner_dim * itemsize): + return False - if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + # FP8 special case: inner ≥ 32 + if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq( inner_dim, 32 ): return False - inner_bytes = inner_dim * dtype.itemsize - return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + return True return ( config.triton.enable_persistent_tma_matmul @@ -1601,8 +1649,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: if not try_import_cutlass(): log.warning( "Failed to import CUTLASS lib. Please check whether " - "_inductor.config.cuda.cutlass_dir is set correctly. " - "Skipping CUTLASS backend for now." + "_inductor.config.cuda.cutlass_dir %s is set correctly. " + "Skipping CUTLASS backend for now.", + config.cuda.cutlass_dir, ) return False return res @@ -1616,20 +1665,15 @@ def _use_cutlass_for_op(op_name: str) -> bool: return op_name.upper() in [x.strip() for x in enabled_ops.split(",")] -decompose_k_threshold = 32 - -# To limit compile time -k_splits_limit = 5 - -# Hand-tuned -default_k_splits = [16, 32, 64, 128, 256] - _IntLike: TypeAlias = Union[int, sympy.Expr] +@functools.cache def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: from torch._inductor.virtualized import V + decompose_k_threshold = config.triton.decompose_k_threshold + return ( not torch.version.hip and V.graph.sizevars.statically_known_true( @@ -1640,15 +1684,21 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool: ) and not V.graph.aot_mode # TODO: Support AOTI for decomposeK and not V.graph.cpp_wrapper - and not config.disable_decompose_k ) @functools.cache def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: + # To limit compile time + k_splits_limit = config.triton.num_decompose_k_splits + + # Hand-tuned + default_k_splits = [16, 32, 64, 128, 256] # If k is a sympy expression, we can't do any splitting if isinstance(k, sympy.Expr) and not k.is_number: return default_k_splits + elif k_splits_limit == 0: + return [] if (isinstance(m, sympy.Expr) and not m.is_number) or ( isinstance(n, sympy.Expr) and not n.is_number @@ -1688,15 +1738,10 @@ def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]: if config.max_autotune_gemm_search_space == "EXHAUSTIVE": return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # If the # of power of 2 divisors are greater than k_splits_limit, return all - # This should be ok for compile time, all perfect squares between 128 and min(k / m, k / n) - # should never be a massive amount - if len(pow_of_2_divisors) >= k_splits_limit: - return pow_of_2_divisors - else: - best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits - # Otherwise, conform results to k_splits_limit - return best_splits[:k_splits_limit] + + best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits + # Otherwise, conform results to k_splits_limit + return best_splits[:k_splits_limit] @functools.cache @@ -1971,7 +2016,6 @@ def call(self, *args: Any, **kwargs: Any) -> None: self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() ) # Skip all the actual compiling. - nonlocal save_output_code save_output_code(wrapper_code.value) if kernel_code: save_output_code(kernel_code.value) @@ -2178,7 +2222,10 @@ def get_device_tflops(dtype: torch.dtype) -> float: from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops - from torch.testing._internal.common_cuda import SM80OrLater + SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( + 8, + 0, + ) assert dtype in (torch.float16, torch.bfloat16, torch.float32) @@ -3072,42 +3119,6 @@ def get_donated_idxs() -> Optional[list[int]]: return None -def set_kernel_post_grad_provenance_tracing( - node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut], - kernel_name: str, - is_extern: bool = False, -) -> None: - from .codegen.simd_kernel_features import DisableReduction, EnableReduction - from .ir import ExternKernelOut - from .virtualized import V - - if is_extern: - assert isinstance(node_schedule, ExternKernelOut) - curr_node_info = ( - V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] - ) - ) - curr_node_info.extend( - origin.name - for origin in node_schedule.origins - if origin.name not in curr_node_info - ) - else: - assert isinstance(node_schedule, list) - for snode in node_schedule: - if snode not in (EnableReduction, DisableReduction): - if snode.node is not None: - curr_node_info = V.debug._inductor_triton_kernel_to_post_grad_node_info.setdefault( - kernel_name, [] - ) - curr_node_info.extend( - origin.name - for origin in snode.node.origins - if origin.name not in curr_node_info - ) - - class TritonAttrsDescriptorVersion(enum.Enum): V0_NO_TRITON = 0 V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 815c3a9d1a37..9a527471c8cc 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -468,13 +468,26 @@ def compiled_module_main( "If None, NCU will use '--set full'." ), ) + parser.add_argument( + "--times", + type=int, + default=10, + help="Number of times to run each benchmark iteration", + ) + parser.add_argument( + "--repeat", + type=int, + default=10, + help="Number of repetitions of each benchmark run", + ) + args = parser.parse_args() if args.benchmark_kernels: benchmark_all_kernels(benchmark_name, args.benchmark_all_configs) else: - times = 10 - repeat = 10 + times = args.times + repeat = args.repeat if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() diff --git a/torch/_logging/__init__.py b/torch/_logging/__init__.py index 6e28319cddc1..d0fdebb23bde 100644 --- a/torch/_logging/__init__.py +++ b/torch/_logging/__init__.py @@ -12,6 +12,7 @@ dtrace_structured, get_structured_logging_overhead, getArtifactLogger, + hide_warnings, LazyString, set_logs, trace_structured, diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 11185e334f5d..ffd3160b47ee 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import functools import hashlib import importlib.util @@ -12,6 +13,7 @@ import sys import tempfile import time +import warnings from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Callable, Generic, Optional, Union @@ -1156,6 +1158,45 @@ def warning_once(logger_obj, *args, **kwargs) -> None: logger_obj.warning(*args, **kwargs) +def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool: + return "The .grad attribute of a Tensor" not in str(message) + + +def user_warning_filter( + message, category, filename, lineno, file=None, line=None +) -> bool: + return not category == UserWarning + + +@contextlib.contextmanager +def hide_warnings(filter_fn=lambda *args, **kwargs: True): + """ + A context manager that temporarily suppresses warnings, + using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning. + + Useful to hide warnings without mutating warnings module state, see: + https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162. + + NOTE: Warnings issued under this context will still be cached in the __warningregistry__ + and count towards the once/default rule. So you should NEVER use this on a user-land function. + + Filter must implement the showwarning API: + def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool: + return True # show this warning entry + """ + prior = warnings.showwarning + + def _showwarning(*args, **kwargs): + if filter_fn(*args, **kwargs): + prior(*args, **kwargs) + + try: + warnings.showwarning = _showwarning + yield + finally: + warnings.showwarning = prior + + class LazyString(Generic[_P]): def __init__( self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index acb7ab2e5a05..ae87e0e17fb3 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2789,7 +2789,13 @@ def meta_qlinear_pointwise( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] - assert output_dtype in [torch.float32, torch.bfloat16, torch.int8, torch.uint8] + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.int8, + torch.uint8, + torch.float8_e4m3fn, + ] out = x.new_empty(output_shape, dtype=output_dtype) return out @@ -2820,7 +2826,13 @@ def meta_qlinear_pointwise_binary( output_shape = list(x.shape) # The weight has been transposed during the qlinear weight prepack process. output_shape[-1] = w.shape[1] - assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] + assert output_dtype in [ + torch.float32, + torch.bfloat16, + torch.uint8, + torch.int8, + torch.float8_e4m3fn, + ] out = x.new_empty(output_shape, dtype=output_dtype) return out @@ -7490,18 +7502,20 @@ def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype): out_size = [offs.size(0), mat1.size(0), mat2.size(1)] else: torch._check( - offs.size(0) == mat2.size(0), "matrix batch sizes have to match" + offs.size(0) == mat2.size(0), lambda: "matrix batch sizes have to match" ) out_size = [mat1.size(0), mat2.size(-1)] else: if mat2_is_2d: torch._check( - offs.size(0) == mat1.size(0), "matrix batch sizes have to match" + offs.size(0) == mat1.size(0), lambda: "matrix batch sizes have to match" ) out_size = [mat1.size(1), mat2.size(1)] else: # regular bmm - torch._check(mat1.size(0) == mat2.size(0), "batched dimension has to match") + torch._check( + mat1.size(0) == mat2.size(0), lambda: "batched dimension has to match" + ) out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)] out_dtype = out_dtype or mat1.dtype diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index 05e82300145d..f192a39dd029 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -169,17 +169,22 @@ def _upcast_int_indices(index): return index +def _has_advanced_indexing(index): + """Check if there's any advanced indexing""" + return any( + isinstance(idx, (Sequence, bool)) + or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0)) + for idx in index + ) + + def _numpy_compatible_indexing(index): """Convert scalar indices to lists when advanced indexing is present for NumPy compatibility.""" if not isinstance(index, tuple): index = (index,) # Check if there's any advanced indexing (sequences, booleans, or tensors) - has_advanced = any( - isinstance(idx, (Sequence, bool)) - or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0)) - for idx in index - ) + has_advanced = _has_advanced_indexing(index) if not has_advanced: return index @@ -206,6 +211,84 @@ def _numpy_compatible_indexing(index): return tuple(converted) +def _get_bool_depth(s): + """Returns the depth of a boolean sequence/tensor""" + if isinstance(s, bool): + return True, 0 + if isinstance(s, torch.Tensor) and s.dtype == torch.bool: + return True, s.ndim + if not (isinstance(s, Sequence) and s and s[0] != s): + return False, 0 + is_bool, depth = _get_bool_depth(s[0]) + return is_bool, depth + 1 + + +def _numpy_empty_ellipsis_patch(index, tensor_ndim): + """ + Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions. + + In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array, + it still acts as a separator between advanced indices. PyTorch doesn't have this behavior. + + This function detects when we have: + 1. Advanced indexing on both sides of an ellipsis + 2. The ellipsis doesn't actually match any dimensions + """ + if not isinstance(index, tuple): + index = (index,) + + # Find ellipsis position + ellipsis_pos = None + for i, idx in enumerate(index): + if idx is Ellipsis: + ellipsis_pos = i + break + + # If no ellipsis, no patch needed + if ellipsis_pos is None: + return index, lambda x: x, lambda x: x + + # Count non-ellipsis dimensions consumed by the index + consumed_dims = 0 + for idx in index: + is_bool, depth = _get_bool_depth(idx) + if is_bool: + consumed_dims += depth + elif idx is Ellipsis or idx is None: + continue + else: + consumed_dims += 1 + + # Calculate how many dimensions the ellipsis should match + ellipsis_dims = tensor_ndim - consumed_dims + + # Check if ellipsis doesn't match any dimensions + if ellipsis_dims == 0: + # Check if we have advanced indexing on both sides of ellipsis + left_advanced = _has_advanced_indexing(index[:ellipsis_pos]) + right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :]) + + if left_advanced and right_advanced: + # This is the case where NumPy and PyTorch differ + # We need to ensure the advanced indices are treated as separated + new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :] + end_ndims = 1 + sum( + 1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice) + ) + + def squeeze_fn(x): + return x.squeeze(-end_ndims) + + def unsqueeze_fn(x): + if isinstance(x, torch.Tensor) and x.ndim >= end_ndims: + return x.unsqueeze(-end_ndims) + return x + + return new_index, squeeze_fn, unsqueeze_fn + + return index, lambda x: x, lambda x: x + + # Used to indicate that a parameter is unspecified (as opposed to explicitly # `None`) class _Unspecified: @@ -507,19 +590,23 @@ def neg_step(i, s): index = _upcast_int_indices(index) # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) - return ndarray(tensor.__getitem__(index)) + # Apply NumPy-compatible empty ellipsis behavior + index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim) + return maybe_squeeze(ndarray(tensor.__getitem__(index))) def __setitem__(self, index, value): index = _util.ndarrays_to_tensors(index) index = _upcast_int_indices(index) # Apply NumPy-compatible indexing conversion index = _numpy_compatible_indexing(index) + # Apply NumPy-compatible empty ellipsis behavior + index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim) if not _dtypes_impl.is_scalar(value): value = normalize_array_like(value) value = _util.cast_if_needed(value, self.tensor.dtype) - return self.tensor.__setitem__(index, value) + return self.tensor.__setitem__(index, maybe_unsqueeze(value)) take = _funcs.take put = _funcs.put diff --git a/torch/_ops.py b/torch/_ops.py index 600f6d9e1ada..fecfebaeaa53 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -415,10 +415,19 @@ def check_overloaded(arg): # TODO(rzou): we should support torch_dispatch calling convention too. result = handler(mode, *args, **kwargs) else: - raise NotImplementedError( - f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " - f"We recommend filing an issue." - ) + if curr_mode.supports_higher_order_operators: + with _pop_mode_temporarily() as mode: + return curr_mode.__torch_dispatch__(self, [], args, kwargs) + else: + raise NotImplementedError( + f"There was no rule registered for HigherOrderOperator {self._name} and mode {curr_mode}." + f"Hint: set {curr_mode}'s supports_higher_order_operators to True." + f" This causes all higher order operators to pass through {curr_mode}'s __torch_dispatch__," + f" so handle them accordingly by" + f" adding support for HigerOrderOperators (in this case, {self._name}) in" + f" {curr_mode}.__torch_dispatch__ or" + f" returning NotImplemented when not supported." + ) if result is not NotImplemented: return result @@ -457,10 +466,12 @@ def check_overloaded(arg): # All handlers returned NotImplemented raise TypeError( - f"Multiple dispatch failed for {self._name}. There was no registered that " - f"did not return NotImplemented. Use HOP.py_impl to register some. " - f"Tried mode: {curr_mode}) and subclasses: " - f"{[type(a) for a in overloaded_args]}" + f"HigherOrderOperator '{self._name}' is not supported for the given input types. " + f"This typically happens when using custom tensor types or dispatch modes that don't " + f"have implementations for this operation.\n\n" + f"Current mode: {curr_mode}\n" + f"Input types: {[type(a).__name__ for a in overloaded_args]}\n\n" + f"To fix this, can add support for '{self._name}' in {curr_mode}'s __torch_dispatch__\n" ) functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] @@ -1475,7 +1486,10 @@ def load_library(self, path): # Import the shared library into the process, thus running its # static (global) initialization code in order to register custom # operators with the JIT. - ctypes.CDLL(path) + try: + ctypes.CDLL(path) + except Exception as e: + raise OSError(f"Could not load this library: {path}") from e self.loaded_libraries.add(path) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 5d24eb42090d..03a3fd91831b 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -5,7 +5,6 @@ import functools import threading import typing -import warnings import weakref from abc import abstractmethod from contextlib import AbstractContextManager, contextmanager @@ -81,8 +80,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool: def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter): return t.grad diff --git a/torch/_tensor.py b/torch/_tensor.py index 652cd33a0353..dd9d987eea66 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1659,7 +1659,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): __torch_dispatch__ = _C._disabled_torch_dispatch_impl - def __dlpack__(self, *, stream=None, max_version=None): + def __dlpack__( + self, + *, + stream: Optional[Any] = None, + max_version: Optional[tuple[int, int]] = None, + dl_device: Optional[tuple[enum.IntEnum, int]] = None, + copy: Optional[bool] = None, + ): """ Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ of the current tensor to be exported to other libraries. @@ -1670,22 +1677,31 @@ def __dlpack__(self, *, stream=None, max_version=None): Args: stream (integer or None): An optional Python integer representing a - pointer to a CUDA stream. The current stream is synchronized with - this stream before the capsule is created, and since the capsule - shares its storage with the tensor this make it safe to access from - both streams. If None or -1 is passed then no synchronization is performed. - If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for - synchronization. + pointer to a CUDA stream. The current stream is synchronized with + this stream before the capsule is created, and since the capsule + shares its storage with the tensor this make it safe to access from + both streams. If None or -1 is passed then no synchronization is performed. + If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for + synchronization. max_version (tuple[int, int] or None): An optional Python tuple with - 2 integers, representing the maximum version the caller supports. If - None (default), PyTorch will fallback to DLPack 0.8. + 2 integers, representing the maximum version the caller supports. If + None (default), PyTorch will fallback to DLPack 0.8. + + dl_device (tuple[DLDeviceType, int] or None): An optional tuple specifying + in which device the exported DLPack capsule should be on. If None (default), + the exported DLPack capsule will be on the same device as ``self``. + + copy (bool or None): An optional boolean indicating whether or not to copy + ``self``. If None, PyTorch will copy only if necessary. """ if has_torch_function_unary(self): args = (self,) kwargs = { "stream": stream, "max_version": max_version, + "dl_device": dl_device, + "copy": copy, } return handle_torch_function(Tensor.__dlpack__, (self,), *args, **kwargs) @@ -1693,37 +1709,59 @@ def __dlpack__(self, *, stream=None, max_version=None): # so we prohibit exporting tensors that would lose their properties like # requires_grad and having the conjugate bit set. if self.requires_grad: - raise RuntimeError( + raise BufferError( "Can't export tensors that require gradient, use tensor.detach()" ) if self.is_conj(): - raise RuntimeError("Can't export tensors with the conjugate bit set") + raise BufferError("Can't export tensors with the conjugate bit set") if self.layout != torch.strided: - raise RuntimeError( + raise BufferError( "Can't export tensors with layout other than torch.strided" ) + if ( + self.device.type == "cuda" + and self.device.index != torch.cuda.current_device() + ): + raise BufferError( + "Can't export tensors on a different CUDA device index. " + f"Expected: {self.device.index}. " + f"Current device: {torch.cuda.current_device()}." + ) + if stream is not None and type(stream) is not int: # Stream pointers in CUDA/ROCm are uniquely numbered and can # be retrieved from their integer value. raise TypeError("stream must be ``int`` or ``none``") - elif stream is not None and stream != -1: - if self.device.type == "cuda": - # NB: This logic handles the special case values for default - # streams and must be kept in sync with from_dlpack in - # torch/utils/dlpack.py - if stream == 1 and torch.version.hip is None: - stream = torch.cuda.default_stream() - elif stream == 0 and torch.version.hip is not None: - stream = torch.cuda.default_stream() - else: - stream = torch.cuda.ExternalStream(stream) - # Only synchronize on different streams - sync_stream = torch.cuda.current_stream() - if stream != sync_stream: - event = torch.cuda.Event() - event.record(sync_stream) - stream.wait_event(event) + elif self.device.type == "cuda" and stream != -1: + # NB: This logic handles the special case values for default + # streams and must be kept in sync with from_dlpack in + # torch/utils/dlpack.py + is_rocm = torch.version.hip is not None + is_cuda = not is_rocm + + if stream is None or (is_rocm and stream == 0) or (is_cuda and stream == 1): + stream = torch.cuda.default_stream() + else: + if is_cuda and stream == 2: + raise BufferError("per-thread default stream is not supported.") + + device_str = "CUDA" if is_cuda else "ROCm" + assert (is_cuda and stream != 0) or ( + is_rocm and stream not in (1, 2) + ), f"unsupported stream on {device_str}: {stream}." + + stream = torch.cuda.ExternalStream(stream) + + # Only synchronize on different streams + current_stream = torch.cuda.current_stream() + if stream != current_stream: + event = torch.cuda.Event() + event.record(current_stream) + stream.wait_event(event) + elif self.device.type == "cpu": + assert stream is None, "stream should be None on cpu." + if self.device.type == "xla": import torch_xla import torch_xla.utils.dlpack as xla_dlpack @@ -1741,9 +1779,9 @@ def __dlpack__(self, *, stream=None, max_version=None): if max_version is None or max_version[0] < 1: # Fallback to the old, unversioned variant. - return torch.to_dlpack(self) + return _C._to_dlpack(self, dl_device=dl_device, copy=copy) - return _C._to_dlpack_versioned(self) + return _C._to_dlpack_versioned(self, dl_device=dl_device, copy=copy) def __dlpack_device__(self) -> tuple[enum.IntEnum, int]: if has_torch_function_unary(self): diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9a96f3c097a5..8c12d4d68930 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -5128,7 +5128,7 @@ def merge_dicts(*dicts): If :attr:`bins` is a sequence of N 1D tensors, it explicitly specifies the N sequences of bin edges. Each 1D tensor should contain a strictly increasing sequence with at least one element. A sequence of K bin edges defines K-1 bins, explicitly specifying -the left and right edges of all bins. Every bin is exclusive of its left edge. Only +the left and right edges of all bins. Every bin is inclusive of its left edge. Only the rightmost bin is inclusive of its right edge. If :attr:`bins` is a sequence of N ints, it specifies the number of equal-width bins @@ -7605,7 +7605,7 @@ def merge_dicts(*dicts): Args: {input} - other (Tensor or Number) - the tensor or number to multiply input by. + other (Tensor or Number): the tensor or number to multiply input by. Keyword args: {out} @@ -8948,7 +8948,7 @@ def merge_dicts(*dicts): Keyword args: {generator} {out} - dtype (`torch.dtype`, optional) - the desired data type of returned tensor. Default: if ``None``, + dtype (torch.dtype, optional): the desired data type of returned tensor. Default: if ``None``, this function returns a tensor with dtype ``torch.int64``. {layout} {device} diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index ffc1792fd23f..f50b9d6cd137 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs +import sys from typing import Callable, Optional, Union import torch @@ -32,8 +33,16 @@ # ensure __module__ is set correctly for public APIs -ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] -ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +if sys.version_info < (3, 12): + ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase] + ObserverOrFakeQuantize.__module__ = "torch.ao.quantization" +else: + from typing import TypeAliasType + + ObserverOrFakeQuantize = TypeAliasType( + "ObserverOrFakeQuantize", Union[ObserverBase, FakeQuantizeBase] + ) + for _f in [ compare_results, extract_results_from_loggers, diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index efee5302ad42..94dfdb8c7626 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import copy +import sys import warnings from collections import namedtuple from typing import Any, Optional, Union @@ -567,8 +568,13 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N ) -QConfigAny = Optional[QConfig] -QConfigAny.__module__ = "torch.ao.quantization.qconfig" +if sys.version_info < (3, 12): + QConfigAny = Optional[QConfig] + QConfigAny.__module__ = "torch.ao.quantization.qconfig" +else: + from typing import TypeAliasType + + QConfigAny = TypeAliasType("QConfigAny", Optional[QConfig]) def _add_module_to_qconfig_obs_ctr( diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index feae45df3b86..e93cd3fdb7cb 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -4,6 +4,7 @@ """ import functools +import sys import warnings from collections import OrderedDict from inspect import getfullargspec, signature @@ -15,8 +16,16 @@ from torch.nn.utils.parametrize import is_parametrized -NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] -NodePattern.__module__ = "torch.ao.quantization.utils" +if sys.version_info < (3, 12): + NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] + NodePattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + NodePattern = TypeAliasType( + "NodePattern", Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any] + ) + # This is the Quantizer class instance from torch/quantization/fx/quantize.py. # Define separately to prevent circular imports. @@ -28,10 +37,27 @@ # Type for fusion patterns, it can be more complicated than the following actually, # see pattern.md for docs # TODO: not sure if typing supports recursive data types -Pattern = Union[ - Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any -] -Pattern.__module__ = "torch.ao.quantization.utils" + +if sys.version_info < (3, 12): + Pattern = Union[ + Callable, + tuple[Callable, Callable], + tuple[Callable, tuple[Callable, Callable]], + Any, + ] + Pattern.__module__ = "torch.ao.quantization.utils" +else: + from typing import TypeAliasType + + Pattern = TypeAliasType( + "Pattern", + Union[ + Callable, + tuple[Callable, Callable], + tuple[Callable, tuple[Callable, Callable]], + Any, + ], + ) # TODO: maybe rename this to MatchInputNode diff --git a/torch/autograd/function.py b/torch/autograd/function.py index b8036a5235b9..ac3aad9f93b5 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -4,8 +4,8 @@ import itertools import warnings from collections import OrderedDict -from typing import Any, Optional -from typing_extensions import deprecated +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import Concatenate, deprecated, ParamSpec import torch import torch._C as _C @@ -29,6 +29,10 @@ # This is incremented in FunctionMeta during class definition AUTOGRAD_FUNCTION_COUNTER = itertools.count() +_T = TypeVar("_T") +_R = TypeVar("_R") +_P = ParamSpec("_P") + # Formerly known as: _ContextMethodMixin class FunctionCtx: @@ -595,11 +599,13 @@ def _is_setup_context_defined(fn): return fn != _SingleLevelFunction.setup_context -def once_differentiable(fn): +def once_differentiable( + fn: Callable[Concatenate[_T, _P], _R], +) -> Callable[Concatenate[_T, _P], _R]: @functools.wraps(fn) - def wrapper(ctx, *args): + def wrapper(ctx: _T, *args: _P.args, **kwargs: _P.kwargs) -> _R: with torch.no_grad(): - outputs = fn(ctx, *args) + outputs = fn(ctx, *args, **kwargs) if not torch.is_grad_enabled(): return outputs @@ -620,12 +626,14 @@ def wrapper(ctx, *args): return outputs if not isinstance(outputs, tuple): - outputs = (outputs,) + outputs_ = (outputs,) + else: + outputs_ = outputs err_fn = _functions.DelayedError( b"trying to differentiate twice a function that was marked " b"with @once_differentiable", - len(outputs), + len(outputs_), ) # Create aliases of each output that has requires_grad=True. We need @@ -637,7 +645,7 @@ def fake_requires_grad(var): var.requires_grad = True return var - return err_fn(*[fake_requires_grad(v) for v in outputs]) + return err_fn(*[fake_requires_grad(v) for v in outputs_]) # type: ignore[return-value] return wrapper diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index de194b12d02c..c5ab7640386a 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -136,5 +136,6 @@ def __init__(self, m, name): mps as mps, nnpack as nnpack, openmp as openmp, + opt_einsum as opt_einsum, quantized as quantized, ) diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 578b56e504e2..163c25f12dbc 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -39,6 +39,8 @@ _P = ParamSpec("_P") _R = TypeVar("_R") +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) def compile(*args, **kwargs): @@ -252,7 +254,10 @@ def disable(fn=None, recursive=True, *, reason=None): def set_stance( - stance: str = "default", *, skip_guard_eval_unsafe=False, force_backend=None + stance: str = "default", + *, + skip_guard_eval_unsafe: bool = False, + force_backend: Union[str, Callable[..., Any], None] = None, ): """ Set the current stance of the compiler. @@ -355,7 +360,7 @@ def set_enable_guard_collectives(enabled: bool): from torch._dynamo.eval_frame import guard_collectives_hook if enabled: - return set_guard_complete_hook(guard_collectives_hook) is not None + return set_guard_complete_hook(guard_collectives_hook) is not None # type: ignore[arg-type] else: return set_guard_complete_hook(None) is not None diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 8df6ea24a4bb..60a7bb644df0 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -74,6 +74,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \ _CATCH_GENERIC_ERROR( \ NotImplementedError, PyExc_NotImplementedError, retstmnt) \ + _CATCH_GENERIC_ERROR(BufferError, PyExc_BufferError, retstmnt) \ _CATCH_GENERIC_ERROR(SyntaxError, PyExc_SyntaxError, retstmnt) \ _CATCH_GENERIC_ERROR(LinAlgError, THPException_LinAlgError, retstmnt) \ _CATCH_GENERIC_ERROR( \ diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 5a0f4a59abe3..15efa62ae978 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -609,25 +609,56 @@ void DLPack_Capsule_Destructor(PyObject* data) { } template -PyObject* THPModule_toDLPackImpl(PyObject* _unused, PyObject* data) { +PyObject* THPModule_toDLPackImpl( + PyObject* self, + PyObject* args, + PyObject* kwargs) { HANDLE_TH_ERRORS - TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor"); - auto tensor = at::DLPackTraits::toDLPack(THPVariable_Unpack(data)); + static torch::PythonArgParser parser( + {"_to_dlpack(Tensor data, *, IntArrayRef? dl_device=None, bool? copy=None)"}); + torch::ParsedArgs<3> parsed_args{}; + auto r = parser.parse(args, kwargs, parsed_args); + + TORCH_INTERNAL_ASSERT(r.idx == 0); + + auto data = r.tensor(0); + auto dl_device = r.intlist(1); + auto copy = r.toBoolOptional(2); + + // Parse the int list into a tuple. + std::optional optional_dl_device; + + if (!dl_device.empty()) { + TORCH_CHECK( + dl_device.size() == 2, + "dl_device must be either None or a tuple of ints"); + optional_dl_device = DLDevice{ + static_cast(dl_device[0]), + static_cast(dl_device[1])}; + } + + auto tensor = at::DLPackTraits::toDLPack( + at::maybeCopyTensor(data, optional_dl_device, copy)); return PyCapsule_New( tensor, at::DLPackTraits::capsule, DLPack_Capsule_Destructor); + END_HANDLE_TH_ERRORS } } // namespace -static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); +static PyObject* THPModule_toDLPack( + PyObject* self, + PyObject* args, + PyObject* kwargs) { + return THPModule_toDLPackImpl(self, args, kwargs); } static PyObject* THPModule_toDLPackVersioned( - PyObject* _unused, - PyObject* data) { - return THPModule_toDLPackImpl(_unused, data); + PyObject* self, + PyObject* args, + PyObject* kwargs) { + return THPModule_toDLPackImpl(self, args, kwargs); } static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { @@ -638,6 +669,28 @@ static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) { END_HANDLE_TH_ERRORS } +static PyObject* THPModule_torchDeviceToDLDevice( + PyObject* _unused, + PyObject* data) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPDevice_Check(data), + "torchDeviceToDLDevice: expected torch.device argument."); + auto device = reinterpret_cast(data)->device; + auto dl_device = at::torchDeviceToDLDevice(device); + + auto tuple = PyTuple_New(2); + if (!tuple) { + throw python_error(); + } + + PyTuple_SET_ITEM(tuple, 0, THPUtils_packInt64(dl_device.device_type)); + PyTuple_SET_ITEM(tuple, 1, THPUtils_packInt64(dl_device.device_id)); + + return tuple; + END_HANDLE_TH_ERRORS +} + static PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) { HANDLE_TH_ERRORS size_t frames_to_skip = 0; @@ -1689,9 +1742,19 @@ static std::initializer_list TorchMethods = { THPModule_are_vmap_fallback_warnings_enabled, METH_NOARGS, nullptr}, - {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr}, - {"_to_dlpack_versioned", THPModule_toDLPackVersioned, METH_O, nullptr}, + {"_to_dlpack", + castPyCFunctionWithKeywords(THPModule_toDLPack), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_to_dlpack_versioned", + castPyCFunctionWithKeywords(THPModule_toDLPackVersioned), + METH_VARARGS | METH_KEYWORDS, + nullptr}, {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr}, + {"_torchDeviceToDLDevice", + THPModule_torchDeviceToDLDevice, + METH_O, + nullptr}, {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr}, {"_rename_privateuse1_backend", THModule_rename_privateuse1_backend, diff --git a/torch/csrc/PyInterpreterHooks.cpp b/torch/csrc/PyInterpreterHooks.cpp new file mode 100644 index 000000000000..5e064493fd59 --- /dev/null +++ b/torch/csrc/PyInterpreterHooks.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace torch::detail { + +PyInterpreterHooks::PyInterpreterHooks(c10::impl::PyInterpreterHooksArgs) {} + +c10::impl::PyInterpreter* PyInterpreterHooks::getPyInterpreter() const { + // Delegate to the existing implementation + return ::getPyInterpreter(); +} + +} // namespace torch::detail + +// Sigh, the registry doesn't support namespaces :( +using c10::impl::PyInterpreterHooksRegistry; +using c10::impl::RegistererPyInterpreterHooksRegistry; +using PyInterpreterHooks = torch::detail::PyInterpreterHooks; +// Register the implementation +REGISTER_PYTHON_HOOKS(PyInterpreterHooks) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 908a980cfee9..8e13d4267edb 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -5023,6 +5023,103 @@ std::tuple layer_norm_double_backward( return std::tuple{gI, gG, ggO}; } +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask) { + c10::MaybeOwned weight_maybe_owned = + at::borrow_from_optional_tensor(weight_opt); + const Tensor& weight = *weight_maybe_owned; + + const auto input_shape = input.sizes(); + const auto input_ndim = input.dim(); + const int normalized_ndim = normalized_shape.size(); + const int axis = input_ndim - normalized_ndim; + + int64_t N_rms = 1; + for (int i = 0; i < normalized_ndim; ++i) { + N_rms *= input_shape[axis + i]; + } + + Tensor dX; + Tensor dgamma; + + std::vector rstd_view_shape = rstd.sizes().vec(); + for (int i = 0; + i < std::max(static_cast(normalized_ndim - rstd.dim()), 0); + ++i) { + rstd_view_shape.push_back(1); + } + Tensor rstd_broadcast = rstd.view(rstd_view_shape); + Tensor rstd_pow3 = rstd_broadcast.pow(3); + Tensor grad_x_hat; + + if (dY.defined()) { + if (weight.defined()) { + grad_x_hat = dY * weight; + } else { + grad_x_hat = dY; + } + } + + if (grad_input_mask[0]) { + Tensor dX_from_dY_path; + Tensor dX_from_drstd_path; + + std::vector inner_sum_dims; + inner_sum_dims.reserve(normalized_ndim); + for (int i = 0; i < normalized_ndim; ++i) { + inner_sum_dims.push_back(axis + i); + } + + if (dY.defined() && grad_x_hat.defined()) { + Tensor sum_input_times_grad_x_hat = + sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true); + dX_from_dY_path = rstd_broadcast * grad_x_hat - + (input * rstd_pow3 / static_cast(N_rms)) * + sum_input_times_grad_x_hat; + } + + if (drstd.defined()) { + Tensor drstd_broadcast = drstd.view(rstd_view_shape); + dX_from_drstd_path = + -(input * rstd_pow3 / static_cast(N_rms)) * drstd_broadcast; + } + + if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) { + dX = dX_from_dY_path + dX_from_drstd_path; + } else if (dX_from_dY_path.defined()) { + dX = dX_from_dY_path; + } else if (dX_from_drstd_path.defined()) { + dX = dX_from_drstd_path; + } + } + + if (grad_input_mask[1] && weight.defined()) { + if (dY.defined()) { + Tensor x_hat = input * rstd_broadcast; + Tensor dgamma_full_shape = dY * x_hat; + + if (axis > 0) { + std::vector outer_sum_dims; + outer_sum_dims.reserve(axis); + for (int i = 0; i < axis; ++i) { + outer_sum_dims.push_back(i); + } + dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false); + } else { + dgamma = dgamma_full_shape; + } + } + } + + return std::make_tuple(dX, dgamma); +} + std::tuple infinitely_differentiable_native_group_norm_backward( const Tensor& dY, @@ -6377,6 +6474,98 @@ Tensor layer_norm_jvp( bias_t.defined() ? bias_t.view(view_size_affine) : bias_t); } +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + + Tensor result_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + result_t = (input_t)*rstd_p + (input_p)*rstd_t; + } else { + result_t = input_t * rstd_p; + auto temp = input_p * rstd_t; + result_t += temp; + } + + std::optional result_p = std::nullopt; + if (weight_p.defined()) { + result_p = std::optional(input_p * rstd_p); + } + + return _affine_jvp( + result_p, + result_t, + weight_p.defined() ? weight_p.view(view_size_affine) : weight_p, + weight_t.defined() ? weight_t.view(view_size_affine) : weight_t, + Tensor()); +} + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape) { + auto dims = std::vector{}; + auto view_size = input_t.sizes().vec(); + auto view_size_affine = input_t.sizes().vec(); + + int64_t numel = 1; + for (const auto i : c10::irange(view_size.size())) { + if (i < view_size.size() - normalized_shape.size()) { + view_size_affine[i] = 1; + } else { + numel *= input_t.size(static_cast(i)); + view_size[i] = 1; + dims.push_back(static_cast(i)); + } + } + + auto rstd_p = saved_rstd.view(view_size); + Tensor rstd_t; + if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || + input_t._is_zerotensor()) { + rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + } else { + rstd_t = input_t * input_p; + rstd_t *= -rstd_p.pow(3); + } + rstd_t = rstd_t.sum(dims, true); + rstd_t /= numel; + return rstd_t; +} + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 0b659973ec34..96864e165a95 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -826,6 +826,15 @@ std::tuple layer_norm_double_backward( c10::SymIntArrayRef normalized_shape, std::array output_mask); +std::tuple infinitely_differentiable_native_rms_norm_backward( + const Tensor& dY, + const Tensor& drstd, + const Tensor& input, + IntArrayRef normalized_shape, + const Tensor& rstd, + const std::optional& weight_opt, + std::array grad_input_mask); + std::tuple householder_product_backward( const Tensor& grad, const Tensor& result, @@ -965,6 +974,20 @@ Tensor layer_norm_jvp( const Tensor& saved_invstd, c10::SymIntArrayRef normalized_shape); +Tensor rms_norm_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& weight_p, + const Tensor& weight_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + +Tensor rms_norm_rstd_jvp( + const Tensor& input_p, + const Tensor& input_t, + const Tensor& saved_rstd, + IntArrayRef normalized_shape); + Tensor group_norm_jvp( const Tensor& input_p, const Tensor& input_t, diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a02e5dda5d99..69e8831936b0 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -1252,7 +1252,6 @@ int PythonTracer::pyProfileFn( local_results.active_tracer_->recordCCall(local_results, frame, arg); break; - case PyTrace_EXCEPTION: case PyTrace_RETURN: local_results.exit_times_.emplace_back(c10::getApproximateTime()); break; diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index ead46337ff09..b44ce311ecd9 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -426,7 +426,8 @@ PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings( PyObject* _unused, PyObject* env) { HANDLE_TH_ERRORS - c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env)); + c10::cuda::CUDACachingAllocator::setAllocatorSettings( + THPUtils_unpackString(env)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index acece3d8c718..76ffdd38d264 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -46,6 +46,7 @@ class TORCH_API Backend : public torch::CustomClassHolder { // backend name // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::string backend; + std::string group_name; }; explicit Backend(int rank, int size); @@ -105,6 +106,16 @@ class TORCH_API Backend : public torch::CustomClassHolder { TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented."); } + // Subclasses must override this method to return the backend name + virtual c10::intrusive_ptr getBackendOptions() { + TORCH_CHECK( + false, + c10::str( + "Backend ", + getBackendName(), + " does not implement getBackendOptions.")); + } + virtual c10::intrusive_ptr broadcast( std::vector& /* tensors */, const BroadcastOptions& /* opts */ = BroadcastOptions()) { @@ -379,6 +390,28 @@ class TORCH_API Backend : public torch::CustomClassHolder { " is missing implementation of enableCollectivesTiming."); } + virtual c10::intrusive_ptr split( + const std::vector& ranks, + const c10::intrusive_ptr& opts) { + TORCH_CHECK( + false, + "Backend ", + getBackendName(), + " is missing implementation of split."); + } + + virtual c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) { + TORCH_CHECK( + false, + "Backend ", + getBackendName(), + " is missing implementation of merge."); + } + bool hasHooks() const { return onCompletionHook_ != nullptr; } diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 60bb0f2d879e..8074cc98a04f 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -573,6 +573,27 @@ size_t hashTensors(const std::vector& tensors) { return hash; } +// NCCL uses Non-negative int to represent in-group according to API +// requirement. We take a list of ranks and generate a hash value based on the +// list and ensure its range of 32-bit int. +int genNcclSplitColor(const std::vector& ranks) { + // Combine the hash values using a simple reducer (std::hash + fold) + std::size_t combined_hash = std::accumulate( + ranks.begin(), + ranks.end(), + std::size_t(0), + [](std::size_t acc, int rank) { + return acc ^ + (std::hash{}(rank) + 0x9e3779b9 + (acc << 6) + (acc >> 2)); + }); + + // max positive value of int32_t + constexpr int32_t max_c_int = std::numeric_limits::max(); + int color = static_cast( + std::abs(static_cast(combined_hash)) % max_c_int); + return color; +} + // Default value: 30 minutes int nccl_nonblocking_timeout() { static int timeout = -2; // -2 means not initialized diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 5e61837c2353..fcd55b6a655e 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -231,6 +231,7 @@ static std::map ncclDataType = { }; TORCH_API size_t hashTensors(const std::vector& tensors); +TORCH_API int genNcclSplitColor(const std::vector& ranks); TORCH_API std::string getNcclVersion(); TORCH_API std::tuple getNcclVersionTuple(); TORCH_API int getNcclVersionNumber(); diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 83418d17acdc..3f183d804129 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -158,6 +158,104 @@ void ProcessGroup::release_resources() { backendTypeToBackend_.clear(); } +c10::intrusive_ptr ProcessGroup::splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& desc) { + TORCH_CHECK( + ranks.size() > 0, + "Split ranks cannot be empty. Please provide a non-empty list of ranks to split the group."); + TORCH_CHECK( + ranks.size() < static_cast(size_), + "the split group's size should be less than the world_size set by init_process_group"); + std::set ranks_set(ranks.begin(), ranks.end()); + TORCH_CHECK( + ranks_set.size() == ranks.size(), + "Split ranks should not have duplicates. Please provide a list of unique ranks to split the group."); + std::vector sorted_ranks = ranks; + std::sort(sorted_ranks.begin(), sorted_ranks.end()); + c10::intrusive_ptr newGroup; + // TODO: Figure out a better way for split group name. + std::string groupName = + c10::str(getGroupName(), ":split:", fmt::format("{}", sorted_ranks)); + for (const auto& pair : deviceTypeToBackendType_) { + c10::DeviceType deviceType = pair.first; + BackendType backendType = pair.second; + + auto parentBackend = getBackend(deviceType); + auto backendOpts = + opts.has_value() ? opts.value() : parentBackend->getBackendOptions(); + backendOpts->group_name = groupName; + backendOpts->timeout = + timeout.has_value() ? timeout.value() : backendOpts->timeout; + auto splitBackend = parentBackend->split(sorted_ranks, backendOpts); + if (splitBackend == nullptr) { + continue; + } + + // TODO: Figure out a better way for split group desc. + // TODO: We can add a new field in Backend::Options to specify the group + // desc + std::string groupDesc = desc.has_value() + ? desc.value() + : c10::str(getGroupDesc(), ":split:", incrementSplitCount()); + splitBackend->setGroupDesc(groupDesc); + + if (!newGroup) { + newGroup = c10::make_intrusive( + store_->clone(), splitBackend->getRank(), splitBackend->getSize()); + newGroup->setDefaultBackend(backendType_); + newGroup->setGroupName(groupName); + newGroup->setGroupDesc(groupDesc); + } + newGroup->setBackend(deviceType, backendType, splitBackend); + } + + return newGroup; +} + +c10::intrusive_ptr ProcessGroup::mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size) { + c10::intrusive_ptr newGroup; + // We assume rank number is within the range of int32_t, so it won't overflow. + int rank = static_cast(store->add("mergeGroupRank", 1) - 1); + // TODO: Do we need to check all groups have same deviceTypeToBackendType_? + for (const auto& pair : deviceTypeToBackendType_) { + c10::DeviceType deviceType = pair.first; + BackendType backendType = pair.second; + + auto parentBackend = getBackend(deviceType); + auto backendOpts = parentBackend->getBackendOptions(); + std::string groupName = opts.group_name.has_value() + ? opts.group_name.value() + : c10::str(getGroupName(), ":merge"); + backendOpts->group_name = groupName; + backendOpts->timeout = opts.timeout; + auto mergedBackend = parentBackend->merge(store, backendOpts, rank, size); + + std::string groupDesc = opts.group_desc.has_value() + ? opts.group_desc.value() + : c10::str(getGroupDesc(), ":merge"); + mergedBackend->setGroupDesc(groupDesc); + + // Historically, we have been using one process_group to map to all + // backends. but in our new design, we will have one process_group per + // backend. This logic is mostly for backward compatibility. + if (!newGroup) { + newGroup = c10::make_intrusive(store, rank, size); + newGroup->setDefaultBackend(backendType_); + newGroup->setGroupName(groupName); + newGroup->setGroupDesc(groupDesc); + } + newGroup->setBackend(deviceType, backendType, mergedBackend); + } + + return newGroup; +} + } // namespace c10d namespace { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index da4bf65f4f39..437564ff9ac6 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -71,6 +71,21 @@ C10_EXPORT bool allow_inflight_collective_as_graph_input(); // class TORCH_API ProcessGroup : public torch::CustomClassHolder { public: + struct TORCH_API MergeOptions : torch::CustomClassHolder { + explicit MergeOptions( + const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout, + const std::optional group_name = std::nullopt, + const std::optional group_desc = std::nullopt) + : timeout(timeout), group_name(group_name), group_desc(group_desc) {} + ~MergeOptions() override = default; + MergeOptions(const MergeOptions&) = delete; + MergeOptions& operator=(const MergeOptions&) = delete; + + std::chrono::milliseconds timeout; + std::optional group_name; + std::optional group_desc; + }; + enum BackendType : uint8_t { UNDEFINED = 0, GLOO = 1, @@ -170,6 +185,10 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { } } + int64_t incrementSplitCount() { + return splitCounter_++; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -955,6 +974,21 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { bound_device_id_ = device; } + // This creates a new subgroup using the specified ranks. + // The current rank must be included in the list of new_ranks. + virtual c10::intrusive_ptr splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& groupDesc); + + // This creates a new subgroup using the specified ranks. + // The current rank must be included in the list of new_ranks. + virtual c10::intrusive_ptr mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size); + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. @@ -968,6 +1002,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) BackendType backendType_; std::string pg_desc_; + int64_t splitCounter_; // Debug level setting. It is parsed once when ProcessGroup is constructed and // remains the same across use of this process group. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 0df6073c5d2d..895915dcc840 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -697,6 +697,47 @@ const std::vector& ProcessGroupGloo::groupRanks() const { return options_->global_ranks_in_group; } +c10::intrusive_ptr ProcessGroupGloo::split( + const std::vector& ranks, + const c10::intrusive_ptr& opts) { + auto it = std::find(ranks.begin(), ranks.end(), rank_); + int groupRank; + if (it == ranks.end()) { + return nullptr; + } else { + groupRank = std::distance(ranks.begin(), it); + } + + auto glooOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options."); + + // TODO: we need to get rid of globalRanksInGroup eventually. + std::vector globalRanksInGroup; + for (auto rank : ranks) { + globalRanksInGroup.emplace_back(groupRanks()[rank]); + } + glooOpts->global_ranks_in_group = std::move(globalRanksInGroup); + auto store = std::dynamic_pointer_cast(store_); + TORCH_CHECK( + store != nullptr, + "store inside ProcessGroupGloo not a ProcessGroupGloo::GlooStore."); + auto pg = c10::make_intrusive( + store->_getStore()->clone(), groupRank, ranks.size(), glooOpts); + return c10::static_intrusive_pointer_cast(pg); +} + +c10::intrusive_ptr ProcessGroupGloo::merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) { + auto glooOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options."); + auto pg = c10::make_intrusive( + store->clone(), rank, size, glooOpts); + return c10::static_intrusive_pointer_cast(pg); +} + void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { std::unique_lock lock(workMutex_); pgStatus_->lastEnqueuedSeq = static_cast(work->seq_); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index e5f1ca740288..fd3fd779229d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -188,6 +188,10 @@ class TORCH_API ProcessGroupGloo : public Backend { } #endif + const c10::intrusive_ptr<::c10d::Store>& _getStore() const { + return store_; + } + protected: c10::intrusive_ptr<::c10d::Store> store_; }; @@ -252,7 +256,6 @@ class TORCH_API ProcessGroupGloo : public Backend { } std::vector global_ranks_in_group; - std::string group_name; std::vector> devices; int threads; }; @@ -301,6 +304,20 @@ class TORCH_API ProcessGroupGloo : public Backend { } } + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } + + c10::intrusive_ptr split( + const std::vector& ranks, + const c10::intrusive_ptr& opts) override; + + c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) override; + const std::vector& groupRanks() const; c10::intrusive_ptr broadcast( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index fda3879a8e8c..ba335dff8c5f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -519,11 +519,9 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( // DEFAULT_FLAGS = cudaEventDisableTiming. if (cudaEventCacheEnabled) { ncclStartEvent_ = enableTiming - ? ProcessGroupNCCL::CUDAEventCache::get(device.index()) - ->create(enableTiming) + ? CUDAEventCache::get(device.index())->create(enableTiming) : nullptr; - ncclEndEvent_ = ProcessGroupNCCL::CUDAEventCache::get(device.index()) - ->create(enableTiming); + ncclEndEvent_ = CUDAEventCache::get(device.index())->create(enableTiming); } else { ncclStartEvent_ = enableTiming ? std::make_shared(cudaEventDefault) @@ -860,61 +858,6 @@ void ProcessGroupNCCL::WorkNCCL::abort() { } } -ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; - -// CUDA event is used to record the start/end of one Work. -// Instead of let the CUDA event gets destroyed, we now reuse it after the Work -// has been erased from workMetaList_. -// This is to avoid the potential deadlock caused by CudaEventDestroy. -std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( - bool timing) { - // Register the deleter as a callback when the WorkNCCL object is destroyed. - // Each deleter keeps a ref count to the cache object, so that even when - // the thread that creates the cache is gone, the cache object won't be - // destroyed until all the events in the cache are destroyed (ref number drops - // to zero). - auto deleter = [cache = shared_from_this(), - timing](at::cuda::CUDAEvent* event) { - std::lock_guard lock(cache->cacheMutex_); - // We put the event back to the cache deque once the WorkNCCL object is - // destroyed. - cache->eventsArray_[timing ? 1 : 0].push_back(event); - }; - at::cuda::CUDAEvent* event = nullptr; - { - std::lock_guard lock(cacheMutex_); - auto& events = eventsArray_[timing ? 1 : 0]; - // If we still have events in the cache, we reuse it. Otherwise, we create a - // new one. - if (!events.empty()) { - event = events.front(); - events.pop_front(); - } else { - event = new at::cuda::CUDAEvent( - timing ? cudaEventDefault : cudaEventDisableTiming); - } - } - return std::shared_ptr(event, std::move(deleter)); -} - -std::shared_ptr ProcessGroupNCCL:: - CUDAEventCache::get(at::DeviceIndex device) { - // A per-thread singleton of device-to-CUDAEventCache map. - // Map is needed because events cannot be reused across devices. - // Per-thread ownership is needed to support multi-threaded case (instead of - // multi-process case). - static thread_local std:: - map> - cacheDeviceMap; - // Check if device has already been in the map, if not, add a new entry - auto it = cacheDeviceMap.find(device); - if (it == cacheDeviceMap.end()) { - cacheDeviceMap.emplace( - device, std::make_shared()); - } - return cacheDeviceMap[device]; -} - static std::atomic process_group_id = 0; constexpr const char* MULTI_DEVICE_ERROR_MSG = @@ -1311,6 +1254,57 @@ void ProcessGroupNCCL::enableCollectivesTiming() { enableTiming_.store(true); } +c10::intrusive_ptr ProcessGroupNCCL::split( + const std::vector& ranks, + const c10::intrusive_ptr& opts) { + auto deviceIdx = guessDeviceId(); + TORCH_CHECK( + deviceIdx >= 0, + "ProcessGroupNCCL::split: rank ", + rank_, + " has no device is bound to this rank."); + auto device = at::Device(at::DeviceType::CUDA, deviceIdx); + auto it = std::find(ranks.begin(), ranks.end(), rank_); + int groupRank; + if (it == ranks.end()) { + // This rank is not in the new group, so no_color split should be called + performNocolorSplit(device); + return nullptr; + } else { + groupRank = std::distance(ranks.begin(), it); + } + + auto ncclOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options."); + + // TODO: we need to get rid of globalRanksInGroup eventually. + std::vector globalRanksInGroup; + for (auto rank : ranks) { + globalRanksInGroup.emplace_back(groupRanks()[rank]); + } + ncclOpts->split_from = + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this); + ncclOpts->global_ranks_in_group = std::move(globalRanksInGroup); + auto color = genNcclSplitColor(ranks); + ncclOpts->split_color = color; + auto pg = c10::make_intrusive( + store_->clone(), groupRank, ranks.size(), ncclOpts); + pg->eagerConnectSingleDevice(device); + return c10::static_intrusive_pointer_cast(pg); +} + +c10::intrusive_ptr ProcessGroupNCCL::merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) { + auto ncclOpts = c10::dynamic_intrusive_pointer_cast(opts); + TORCH_CHECK(ncclOpts != nullptr, "opts not a ProcessGroupNCCL::Options."); + auto pg = c10::make_intrusive( + store->clone(), rank, size, ncclOpts); + return c10::static_intrusive_pointer_cast(pg); +} + bool ProcessGroupNCCL::waitForFutureOrTimeout( std::future& fut, const std::chrono::milliseconds& timeOutMilSec, diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index bf7ac47d8ed1..dd35afc155f3 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -503,23 +504,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { friend class ProcessGroupNCCL; }; - class CUDAEventCache - : public std::enable_shared_from_this { - public: - CUDAEventCache(); - std::shared_ptr create(bool timing); - static std::shared_ptr get( - at::DeviceIndex device); - - private: - std::mutex cacheMutex_; - // NOTE: We intentionally store raw pointers so that - // we do not attempt to destroy the event objects on process exit, - // because cuda may be gone. - std::array, 2> - eventsArray_; // 0 for timing=false, 1 for timing=true - }; - struct Options : Backend::Options { // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for // operations. This is only used when blockingWait_ is enabled. @@ -541,7 +525,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Optional "parent" backend and color to create communicators from // via `ncclCommSplit` - std::shared_ptr split_from; + c10::intrusive_ptr split_from; // Color to use for `ncclCommSplit`, values: // * Non-negative value: in group; // * NCCL_SPLIT_NOCOLOR (-1): not in group; @@ -562,7 +546,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { int split_color{-2}; #endif std::vector global_ranks_in_group; - std::string group_name; }; // Helper class related to TORCH_NCCL_DESYNC_DEBUG @@ -804,6 +787,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { return options_; } + c10::intrusive_ptr getBackendOptions() override { + return c10::static_intrusive_pointer_cast(options_); + } + const std::string getBackendName() const override { return std::string(NCCL_BACKEND_NAME); } @@ -972,6 +959,16 @@ class TORCH_API ProcessGroupNCCL : public Backend { void enableCollectivesTiming() override; + c10::intrusive_ptr split( + const std::vector& ranks, + const c10::intrusive_ptr& opts) override; + + c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) override; + // Helper function for iteratively aborting communicators in the provided map void abortCommsFromMap( std::unordered_map>& ncclCommsMap, diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index 3355d0feebfb..c54004310517 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -151,6 +151,34 @@ class PyProcessGroup : public ProcessGroup { group_desc); } + c10::intrusive_ptr splitGroup( + const std::vector& ranks, + const std::optional timeout, + const std::optional> opts, + const std::optional& group_desc) override { + PYBIND11_OVERRIDE( + c10::intrusive_ptr, /* Return type */ + ProcessGroup, /* Parent class */ + splitGroup, /* Name of function in C++ */ + ranks, + timeout, + opts, + group_desc); + } + + c10::intrusive_ptr mergeRemoteGroup( + const c10::intrusive_ptr& store, + const MergeOptions& opts, + const int& size) override { + PYBIND11_OVERRIDE( + c10::intrusive_ptr, /* Return type */ + ProcessGroup, /* Parent class */ + mergeRemoteGroup, /* Name of function in C++ */ + store, + opts, + size); + } + c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, diff --git a/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp new file mode 100644 index 000000000000..75208e92b408 --- /dev/null +++ b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.cpp @@ -0,0 +1,58 @@ +#include +#include +#include + +namespace c10d { + +CUDAEventCache::CUDAEventCache() = default; + +// CUDA event is used to record the start/end of one Work. +// Instead of let the CUDA event gets destroyed, we now reuse it after the Work +// has been erased from workMetaList_. +// This is to avoid the potential deadlock caused by CudaEventDestroy. +std::shared_ptr CUDAEventCache::create(bool timing) { + // Register the deleter as a callback when the WorkNCCL object is destroyed. + // Each deleter keeps a ref count to the cache object, so that even when + // the thread that creates the cache is gone, the cache object won't be + // destroyed until all the events in the cache are destroyed (ref number drops + // to zero). + auto deleter = [cache = shared_from_this(), + timing](at::cuda::CUDAEvent* event) { + std::lock_guard lock(cache->cacheMutex_); + // We put the event back to the cache deque once the WorkNCCL object is + // destroyed. + cache->eventsArray_[timing ? 1 : 0].push_back(event); + }; + at::cuda::CUDAEvent* event = nullptr; + { + std::lock_guard lock(cacheMutex_); + auto& events = eventsArray_[timing ? 1 : 0]; + // If we still have events in the cache, we reuse it. Otherwise, we create a + // new one. + if (!events.empty()) { + event = events.front(); + events.pop_front(); + } else { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); + } + } + return std::shared_ptr(event, std::move(deleter)); +} + +std::shared_ptr CUDAEventCache::get(at::DeviceIndex device) { + // A per-thread singleton of device-to-CUDAEventCache map. + // Map is needed because events cannot be reused across devices. + // Per-thread ownership is needed to support multi-threaded case (instead of + // multi-process case). + static thread_local std::map> + cacheDeviceMap; + // Check if device has already been in the map, if not, add a new entry + auto it = cacheDeviceMap.find(device); + if (it == cacheDeviceMap.end()) { + cacheDeviceMap.emplace(device, std::make_shared()); + } + return cacheDeviceMap[device]; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp new file mode 100644 index 000000000000..5639c1f04fd7 --- /dev/null +++ b/torch/csrc/distributed/c10d/cuda/CUDAEventCache.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace c10d { + +class TORCH_API CUDAEventCache + : public std::enable_shared_from_this { + public: + CUDAEventCache(); + std::shared_ptr create(bool timing); + static std::shared_ptr get(at::DeviceIndex device); + + private: + std::mutex cacheMutex_; + // NOTE: We intentionally store raw pointers so that + // we do not attempt to destroy the event objects on process exit, + // because cuda may be gone. + std::array, 2> + eventsArray_; // 0 for timing=false, 1 for timing=true +}; + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu index 58533ece6af8..db4a118a25e5 100644 --- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cu +++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cu @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -8,7 +10,7 @@ #include #include #else -#include +#include #endif namespace c10d::cuda::detail { @@ -21,19 +23,49 @@ __device__ void nanosleep(int64_t ns) { #endif } +__device__ int32_t load_cpu_int32(int32_t* ptr) { +#if defined(USE_ROCM) + // WARNING: this may not be safe + return atomicAdd_system(ptr, 0); +#else + int32_t current_value = 0; + + // Bypass L1 cache to see updates at L2 and above. + // This could use .cv to bypass L2 cache but that's significantly more + // expensive and the CPU write will clear the L2 cache. + // https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators + asm volatile("ld.cg.s32 %0, [%1];" + : "=r"(current_value) // Output operand + : "l"(ptr) // Input operand + ); + return current_value; +#endif +} + +__device__ void store_cpu_int32(int32_t* ptr, int32_t val) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)) + // WARNING: this value may be cached without .release + *ptr = val; +#else + // Releases memory so it can be seen by other threads on the system. + // https://docs.nvidia.com/cuda/parallel-thread-execution/#release-acquire-patterns + asm volatile("st.release.sys.s32 [%0], %1;" ::"l"(ptr), "r"(val)); +#endif +} + __global__ // set launch bounds to limit to 1 thread per block, 1 block per MP __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) { - value[1] = StreamBlockStatus::RUNNING; + store_cpu_int32(&value[1], StreamBlockStatus::RUNNING); size_t start = c10d::symmetric_memory::global_timer_ns(); size_t timeout_ns = timeout_ms * 1e6; // Convert milliseconds to nanoseconds while (true) { // Atomically read the value - int current_value = atomicAdd(&value[0], 0); + int32_t current_value = load_cpu_int32(value); // Check if the value is equal to the expected value if (current_value == 1) { - value[1] = StreamBlockStatus::ABORTED; + store_cpu_int32(&value[1], StreamBlockStatus::ABORTED); return; } @@ -41,7 +73,7 @@ __launch_bounds__(1, 1) void kernel_barrier(int32_t* value, size_t timeout_ms) { // Check if timeout has been reached size_t now = c10d::symmetric_memory::global_timer_ns(); if ((now - start) > timeout_ns) { - value[1] = StreamBlockStatus::TIMED_OUT; + store_cpu_int32(&value[1], StreamBlockStatus::TIMED_OUT); return; } } @@ -55,13 +87,21 @@ StreamBlock::StreamBlock(std::chrono::milliseconds timeout) : comm_{ // We need to pin the memory since we access the CPU memory directly form // the GPU. - at::empty({2}, at::TensorOptions().dtype(at::kInt)).pin_memory() + at::zeros({2}, at::TensorOptions().dtype(at::kInt)).pin_memory() }, timeout_{timeout} { + auto stream = at::cuda::getCurrentCUDAStream(); + auto* ptr = comm_.mutable_data_ptr(); + auto* ctx = comm_.storage().data_ptr().get_context(); + // grid size 1, block size 1, 0 bytes of shared memory - kernel_barrier<<<1, 1, 0>>>( - comm_.mutable_data_ptr(), timeout_.count()); + kernel_barrier<<<1, 1, 0, stream>>>(ptr, timeout_.count()); C10_CUDA_KERNEL_LAUNCH_CHECK(); + + // This object may be deallocated before the CUDA kernel completes. We need to + // register the CPU tensor so it's only freed after the kernel completes + // execution. + at::getHostAllocator(at::kCUDA)->record_event(ptr, ctx, stream.unwrap()); } C10_REGISTER_CLASS(StreamBlockRegistry, CUDA, StreamBlock) diff --git a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh index f94f272d7eef..9ca52b4c5e88 100644 --- a/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh +++ b/torch/csrc/distributed/c10d/cuda/StreamBlock.cuh @@ -13,6 +13,7 @@ class StreamBlock : public ::c10d::cuda::StreamBlock { StreamBlock(std::chrono::milliseconds timeout); void abort() override { + std::atomic_thread_fence(std::memory_order_seq_cst); comm_[0] = 1; } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 0121bd6fd94b..8f617c269ff9 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2063,6 +2063,34 @@ communication mechanism. .def("rank", &::c10d::ProcessGroup::getRank, R"(Get the rank of this process group.)") .def("size", &::c10d::ProcessGroup::getSize, R"(Get the size of this process group.)") .def("name", &::c10d::ProcessGroup::getBackendName, R"(Get the name of this process group.)") + .def( + "split_group", + &::c10d::ProcessGroup::splitGroup, + py::arg("ranks"), + py::arg("timeout") = std::nullopt, + py::arg("opts") = std::nullopt, + py::arg("groupDesc") = std::nullopt, + py::call_guard()) + .def( + "merge_remote_group", + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const c10::intrusive_ptr<::c10d::Store>& store, + int size, + std::chrono::milliseconds timeout, + std::optional groupName, + std::optional groupDesc) { + ::c10d::ProcessGroup::MergeOptions opts; + opts.timeout = timeout; + opts.group_name = groupName; + opts.group_desc = groupDesc; + return self->mergeRemoteGroup(store, opts, size); + }, + py::arg("store"), + py::arg("size"), + py::arg("timeout") = kProcessGroupDefaultTimeout, + py::arg("group_name") = std::nullopt, + py::arg("group_desc") = std::nullopt, + py::call_guard()) .def( "abort", &::c10d::ProcessGroup::abort, diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h index bf5ea9a446bb..0abbc84ebe52 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h @@ -115,54 +115,44 @@ __device__ __forceinline__ void wait_signal(uint32_t* addr) { // Pattern 0: Ensures that all writes to symm_mem buffers from previous // kernels across all devices are visible to the current kernel: // -// sync_remote_blocks(...); +// sync_remote_blocks(...); // __syncthreads(); // // Pattern 1: Ensures that all writes to symm_mem buffers from the current // block are visible to all remote blocks with matching blockIdx: // // __syncthreads(); -// sync_remote_blocks(...); +// sync_remote_blocks(...); // __syncthreads(); // // Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe // for writing by subsequent kernels across all devices. // // __syncthreads(); -// sync_remote_blocks(...); -template +// sync_remote_blocks(...); +template __device__ __forceinline__ void sync_remote_blocks( - uint32_t** signal_pads, - size_t rank, - size_t world_size); - -template <> -__device__ __forceinline__ void sync_remote_blocks( - uint32_t** signal_pads, - size_t rank, - size_t world_size) { - if (threadIdx.x < world_size) { - auto target_rank = threadIdx.x; - put_signal( - signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal( - signal_pads[rank] + blockIdx.x * world_size + target_rank); - } -} - -template <> -__device__ __forceinline__ void sync_remote_blocks( uint32_t** signal_pads, size_t rank, size_t world_size) { if (threadIdx.x < world_size) { auto target_rank = threadIdx.x; - put_signal( - signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal( - signal_pads[rank] + blockIdx.x * world_size + target_rank); + if constexpr (hasPrevMemAccess) { + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + } else { + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + } + if constexpr (hasSubsequentMemAccess) { + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } else { + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); + } } -} +}; template struct MultimemLdReduce { diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu index c4f38e468192..3a004ae73ce7 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu @@ -134,7 +134,7 @@ static __global__ void multimem_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -152,7 +152,7 @@ static __global__ void multimem_all_reduce_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_all_reduce_( @@ -219,7 +219,7 @@ static __global__ void multimem_one_shot_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; @@ -230,7 +230,7 @@ static __global__ void multimem_one_shot_all_reduce_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_one_shot_all_reduce_out( @@ -311,7 +311,7 @@ static __global__ void multimem_all_gather_kernel( uint32_t** signal_pads, size_t rank, size_t world_size) { - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t start = bytes_per_rank * rank; @@ -324,7 +324,7 @@ static __global__ void multimem_all_gather_kernel( } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_all_gather_out( @@ -425,7 +425,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } } // TODO make it sync with one block for no-copy case - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); for (size_t i = offset; i < numel; i += stride) { @@ -435,7 +435,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor one_shot_all_reduce_out_impl( @@ -587,7 +587,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ constexpr size_t numel_per_thread = alignment / sizeof(T); int32_t N_last_dim = last_dim_size / world_size; // used only for split_last_dim reduce_scatter - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -619,7 +619,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); if constexpr (reduce_scatter) { return; } @@ -654,7 +654,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ // need to make sure all blocks exit simultaneously so that the data // is not corrupted by the subsequent kernels __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } template @@ -669,7 +669,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = @@ -692,7 +692,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor two_shot_all_reduce_impl( diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu index 4f69c4974386..55695ca27c8e 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -98,7 +98,7 @@ class NCCLSymmetricMemory : public SymmetricMemory { int rank, c10::IntArrayRef sizes, c10::ScalarType dtype, - int64_t storage_offset) { + int64_t storage_offset) override { // TODO: deduplicate const size_t numel = std::accumulate( sizes.begin(), diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index c25e83c07c6d..190752070250 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -8,8 +8,10 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp index 4c326b6a0e27..03b43184d143 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_cuda.cpp @@ -7,10 +7,12 @@ #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") #include #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 86308ae6cdf3..f28aefc06dee 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -6,8 +6,10 @@ #include C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_POP() namespace torch::distributed::rpc { namespace { diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index cba8158213c6..87689f34dfae 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -1106,14 +1106,16 @@ struct IValuePacker { // That's what the TypePtr is for: it contains the information to do the // parsing. See torch::jit::toIValue for more information. static at::TypePtr packed_type() { -#ifdef _WIN32 +#if defined(_WIN32) +#if defined(USE_CUDA) || defined(USE_ROCM) // NB: the if-constexpr usage triggers compilation errors on Windows // with certain compiler settings // (see https://github.com/pytorch/pytorch/pull/144707 for examples). // It's not clear what the problem is, so we're going to ignore it for now. TORCH_CHECK_NOT_IMPLEMENTED( - false, "torch.compile not supported on Windows"); -#else + false, "torch.compile not supported on Windows GPU"); +#endif +#endif if constexpr (::std::is_same_v) { return at::TensorType::get(); } else if constexpr (::std::is_same_v) { @@ -1153,7 +1155,6 @@ struct IValuePacker { false, "IValuePacker not implemented for type"); return at::NoneType::get(); } -#endif } }; diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c index b68ef894aeaa..244d4165d5e8 100644 --- a/torch/csrc/dynamo/cpython_defs.c +++ b/torch/csrc/dynamo/cpython_defs.c @@ -2,6 +2,20 @@ #include #include +#if IS_PYTHON_3_14_PLUS + +const uint8_t* THP_PyOpcode_Caches = NULL; +const int THP_PyOpcode_Caches_size = 0; + +void +THP_PyThreadState_PopFrame(PyThreadState *tstate, _PyInterpreterFrame * frame) +{} +void +THP_PyFrame_Clear(_PyInterpreterFrame *frame) +{} + +#else + #if IS_PYTHON_3_11_PLUS #define Py_BUILD_CORE @@ -360,3 +374,5 @@ const uint8_t* THP_PyOpcode_Caches = NULL; const int THP_PyOpcode_Caches_size = 0; #endif + +#endif // IS_PYTHON_3_14_PLUS \ No newline at end of file diff --git a/torch/csrc/dynamo/cpython_includes.h b/torch/csrc/dynamo/cpython_includes.h index 6b99c1d5aec8..616be16563cf 100644 --- a/torch/csrc/dynamo/cpython_includes.h +++ b/torch/csrc/dynamo/cpython_includes.h @@ -21,6 +21,14 @@ #if IS_PYTHON_3_11_PLUS #include +#if IS_PYTHON_3_14_PLUS +#include +#include +#endif +#endif + +#if IS_PYTHON_3_14_PLUS +#include #endif #undef Py_BUILD_CORE @@ -30,6 +38,13 @@ extern "C" { #endif +#if IS_PYTHON_3_14_PLUS + +#define F_CODE(x) (PyCodeObject*)PyStackRef_AsPyObjectBorrow(x->f_executable) +#define PREV_INSTR(x) (x)->instr_ptr + +#else + #if IS_PYTHON_3_13_PLUS #define F_CODE(x) ((PyCodeObject*)(x)->f_executable) #define PREV_INSTR(x) (x)->instr_ptr @@ -38,6 +53,8 @@ extern "C" { #define PREV_INSTR(x) (x)->prev_instr #endif +#endif // IS_PYTHON_3_14_PLUS + #if IS_PYTHON_3_12_PLUS #define FUNC(x) ((x)->f_funcobj) #else diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index f413782b2d30..7d00c7ba1abf 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -34,6 +34,14 @@ void eval_frame_callback_set(PyObject* obj) { PyThread_tss_set(&eval_frame_callback_key, obj); } +#if IS_PYTHON_3_12_PLUS +const size_t sys_monitoring_num_callables = + sizeof((PyInterpreterState){0}.monitoring_callables) / sizeof(PyObject*); +PyObject** get_monitoring_callables(PyInterpreterState* interp) { + return (PyObject**)interp->monitoring_callables; +} +#endif + // 3.14 Not supported at all. See cpython_defs.c for hints #if !(IS_PYTHON_3_14_PLUS) @@ -224,17 +232,6 @@ const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) { return PyUnicode_AsUTF8(F_CODE(frame)->co_name); } -void clear_old_frame_if_python_312_plus( - PyThreadState* tstate, - THP_EVAL_API_FRAME_OBJECT* frame) { -#if IS_PYTHON_3_12_PLUS - - THP_PyFrame_Clear(frame); - THP_PyThreadState_PopFrame(tstate, frame); - -#endif -} - static PyObject* dynamo_eval_custom_code_impl( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -485,6 +482,18 @@ static PyObject* dynamo__custom_eval_frame_shim( static void enable_eval_frame_shim(PyThreadState* tstate) {} static void enable_eval_frame_default(PyThreadState* tstate) {} +PyObject* dynamo_eval_custom_code( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame, + PyCodeObject* code, + const char* trace_annotation, + int throw_flag) {} +THPPyInterpreterFrame* THPPyInterpreterFrame_New( + THP_EVAL_API_FRAME_OBJECT* frame) {} +PyObject* dynamo_eval_frame_default( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame, + int throw_flag) {} static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL}; @@ -498,6 +507,17 @@ static PyTypeObject THPPyInterpreterFrameType = { #endif // !(IS_PYTHON_3_14_PLUS) +void clear_old_frame_if_python_312_plus( + PyThreadState* tstate, + THP_EVAL_API_FRAME_OBJECT* frame) { +#if IS_PYTHON_3_12_PLUS + + THP_PyFrame_Clear(frame); + THP_PyThreadState_PopFrame(tstate, frame); + +#endif +} + static PyObject* increment_working_threads( PyThreadState* tstate, PyObject* module) { @@ -570,6 +590,23 @@ static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) { "python enabled=%d and is run_only=%d", callback != Py_None, callback == Py_False); +#if IS_PYTHON_3_12_PLUS + // skip tracing sys.monitoring callables + if (callback != Py_None && callback != Py_False) { + PyInterpreterState* interp = PyThreadState_GET()->interp; + PyObject** monitoring_callables_flat = + (PyObject**)interp->monitoring_callables; + for (size_t i = 0; i < sys_monitoring_num_callables; ++i) { + PyObject* callable = monitoring_callables_flat[i]; + if (callable != NULL && PyFunction_Check(callable)) { + PyFunctionObject* func = (PyFunctionObject*)callable; + if (func->func_code != NULL) { + skip_code_recursive((PyCodeObject*)func->func_code); + } + } + } + } +#endif return set_eval_frame(callback, PyThreadState_GET(), module); } diff --git a/torch/csrc/dynamo/eval_frame.h b/torch/csrc/dynamo/eval_frame.h index 870603262ddb..e8742e37fb63 100644 --- a/torch/csrc/dynamo/eval_frame.h +++ b/torch/csrc/dynamo/eval_frame.h @@ -11,6 +11,11 @@ PyObject* torch_c_dynamo_eval_frame_init(void); #endif +#if IS_PYTHON_3_12_PLUS +extern const size_t sys_monitoring_num_callables; +PyObject** get_monitoring_callables(PyInterpreterState* interp); +#endif + // All the eval APIs change in 3.11 so we need to decide which one to use on the // fly https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction #if IS_PYTHON_3_11_PLUS diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index e05de24259e0..1d42722afaf9 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -7,6 +7,10 @@ #include #include +#include +#include +#include + extern "C" { extern PyObject* guard_complete_hook; } @@ -335,3 +339,14 @@ PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) { extra_state_set_exec_strategy(extra, strategy); Py_RETURN_NONE; } + +void skip_code_recursive(PyCodeObject* code) { + ExtraState* extra = get_extra_state(code); + if (extra == nullptr) { + extra = init_and_set_extra_state(code); + } + + FrameExecStrategy strategy = + FrameExecStrategy{FrameAction::SKIP, FrameAction::SKIP}; + extra_state_set_exec_strategy(extra, strategy); +} diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index ebbad47ef81b..2f3587094f76 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -17,6 +17,7 @@ PyObject* dynamo__custom_eval_frame( PyObject* callback); PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj); +void skip_code_recursive(PyCodeObject* code); #ifdef __cplusplus diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp index b839fb26fc91..c4ee36d87767 100644 --- a/torch/csrc/dynamo/framelocals_mapping.cpp +++ b/torch/csrc/dynamo/framelocals_mapping.cpp @@ -26,9 +26,13 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) PyCodeObject* co = F_CODE(frame); _framelocals.resize(co->co_nlocalsplus, nullptr); +#if IS_PYTHON_3_14_PLUS + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else if (!frame->stacktop) { return; } +#endif auto update_framelocals = [&](int i, PyObject* value) { _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); @@ -53,11 +57,21 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame) }; auto offset = co->co_nlocalsplus - co->co_nfreevars; +#if IS_PYTHON_3_14_PLUS + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else for (int i = 0; i < offset; i++) { update_framelocals(i, frame->localsplus[i]); } +#endif + // Get references to closure variables +#if IS_PYTHON_3_14_PLUS + PyObject* closure; + TORCH_CHECK(false, "Python 3.14+ not supported"); +#else PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure; +#endif for (int i = 0; i < co->co_nfreevars; i++) { update_framelocals(offset + i, PyTuple_GET_ITEM(closure, i)); } diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 83fb0adbe6c9..eb0f20f1c86e 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -36,6 +36,45 @@ #include #include +// Uncomment next line to count instructions for guard eval. +// #define GUARD_INSTRUCTION_COUNT +#ifdef GUARD_INSTRUCTION_COUNT +#include +#include +#include +#include +#include +#include + +int open_counter() { + perf_event_attr attr{}; + attr.type = PERF_TYPE_HARDWARE; + attr.size = sizeof(attr); + attr.config = PERF_COUNT_HW_INSTRUCTIONS; // retired instructions + attr.disabled = 1; // start stopped + attr.exclude_kernel = 1; // user-space only + attr.exclude_hv = 1; + + return syscall(__NR_perf_event_open, &attr, 0, -1, -1, 0); +} + +uint64_t count_instructions(const std::function& fn) { + int fd = open_counter(); + if (fd == -1) + throw std::runtime_error("perf_event_open failed"); + + ioctl(fd, PERF_EVENT_IOC_RESET, 0); + ioctl(fd, PERF_EVENT_IOC_ENABLE, 0); + fn(); // run the code you care about + ioctl(fd, PERF_EVENT_IOC_DISABLE, 0); + + uint64_t count; + read(fd, &count, sizeof(count)); + close(fd); + return count; +} +#endif + // Certain CPython data structures are defined in `.c` files in earlier Python // versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to // retrieve the underlying tuple and access the item. Before Python 3.12 @@ -1100,6 +1139,22 @@ std::string get_exception_message() { return exc_message; } +bool is_nn_module(py::handle example_value) { + py::object torch_module_cls = py::module_::import("torch.nn").attr("Module"); + return py::isinstance(example_value, torch_module_cls); +} + +std::string get_type_str(py::handle example_value) { + std::string type_name; + try { + type_name = py::str(py::type::of(example_value)).cast(); + } catch (const py::error_already_set& e) { + // Fallback that never throws in release builds + type_name = ""; + } + return type_name; +} + bool is_immutable_object(py::handle example_value) { py::object config_module = py::module_::import("torch._dynamo.config"); @@ -2515,9 +2570,13 @@ class GuardManager { py::handle example_value) : _root(root), _source(std::move(source)), - _is_dict(py::isinstance(example_value)) { + _is_dict(py::isinstance(example_value)), + _is_immutable(is_immutable_object(example_value)), + _is_nn_module(is_nn_module(example_value)), + _type_str(get_type_str(example_value)) { if (_is_dict) { _dict_tag = get_dict_version_unchecked(example_value.ptr()); + _is_empty_dict = PyDict_Size(example_value.ptr()) == 0; } } @@ -2537,10 +2596,45 @@ class GuardManager { _leaf_guards.emplace_back(std::move(leaf_guard)); } + public: + // type related helpers + bool is_guarded_value_immutable() { + return _is_immutable; + } + + bool is_guarded_value_nn_module() { + return _is_nn_module; + } + + bool is_guarded_value_dict() { + return _is_dict; + } + + bool is_guarded_value_empty_dict() { + return _is_empty_dict; + } + + std::string type_of_guarded_value() { + return _type_str; + } + public: // For cloning - GuardManager(RootGuardManager* root, std::string source, bool is_dict) - : _root(root), _source(std::move(source)), _is_dict(is_dict) {} + GuardManager( + RootGuardManager* root, + std::string source, + bool is_dict, + bool is_empty_dict, + bool is_immutable, + bool is_nn_module, + std::string type_str) + : _root(root), + _source(std::move(source)), + _is_dict(is_dict), + _is_empty_dict(is_empty_dict), + _is_immutable(is_immutable), + _is_nn_module(is_nn_module), + _type_str(std::move(type_str)) {} void clone_common( RootGuardManager* cloned_root, @@ -2571,7 +2665,14 @@ class GuardManager { if (!py::cast(clone_filter_fn(this))) { return nullptr; } - GuardManager* cloned_mgr = new GuardManager(cloned_root, _source, _is_dict); + GuardManager* cloned_mgr = new GuardManager( + cloned_root, + _source, + _is_dict, + _is_empty_dict, + _is_immutable, + _is_nn_module, + _type_str); clone_common(cloned_root, cloned_mgr, clone_filter_fn); return cloned_mgr; } @@ -2851,7 +2952,11 @@ class GuardManager { // to enable fail fast for the next check. std::vector> _accessors; - bool _is_dict; + bool _is_dict = false; + bool _is_empty_dict = false; + bool _is_immutable = false; + bool _is_nn_module = false; + std::string _type_str; uint64_t _dict_tag{0}; }; @@ -3137,7 +3242,7 @@ class DictGuardManager : public GuardManager { RootGuardManager* root, std::string source, py::handle example_value) - : GuardManager(root, std::move(source)), + : GuardManager(root, std::move(source), example_value), _size(PyDict_Size(example_value.ptr())), _expected_type(Py_TYPE(example_value.ptr())), _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {} @@ -3352,8 +3457,17 @@ class DictGuardManager : public GuardManager { Py_ssize_t size, PyTypeObject* expected_type, bool is_exact_dict_type, - std::vector indices) - : GuardManager(cloned_root, std::move(source), true), + std::vector indices, + std::string type_of, + bool is_empty_dict) + : GuardManager( + cloned_root, + std::move(source), + true, // _is_dict + is_empty_dict, + false, // _is_nn_module + false, // _is_immutable + std::move(type_of)), _size(size), _expected_type(expected_type), _is_exact_dict_type(is_exact_dict_type), @@ -3372,7 +3486,9 @@ class DictGuardManager : public GuardManager { _size, _expected_type, _is_exact_dict_type, - _indices); + _indices, + type_of_guarded_value(), + is_guarded_value_empty_dict()); clone_common(cloned_root, cloned_mgr, clone_filter_fn); for (auto index : _indices) { @@ -3495,7 +3611,7 @@ std::unique_ptr make_guard_manager( throw py::type_error("Invalid guard manager enum"); } } - return std::make_unique(root, std::move(source)); + return std::make_unique(root, std::move(source), example_value); } class TORCH_FUNCTION_MODE_STACK : public LeafGuard { @@ -3840,6 +3956,13 @@ class GetGenericDictGuardAccessor : public GuardAccessor { // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { // borrowed ref + // NOTE for future guard optimization developers - We tried saving the dict + // pointer and weakref of the original object to avoid calling + // PyObject_GenericGetDict on a fast path, but this did not lead any + // meaningful speedups because of 2 reasons + // 1) Once __dict__ is generated, accessing it the second time is fast. + // 2) Getting the object from weakref, from 3.13 onwards, requires + // Py_DECREF, which further eats into the benefit. PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref if (x == nullptr) { // Attribute absent, clear the exception and return false. @@ -5482,6 +5605,7 @@ void install_storage_overlapping_guard( /* overlapping= */ false); } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-volatile") char flush_cache_by_eviction() { constexpr size_t evict_size = 32 * 1024 * 1024; std::vector buffer(evict_size, 1); @@ -5492,6 +5616,7 @@ char flush_cache_by_eviction() { } return sink; } +C10_DIAGNOSTIC_POP() double profile_guard_manager( RootGuardManager* root, @@ -5547,6 +5672,13 @@ bool run_root_guard_manager(void* root, FrameLocalsMapping* f_locals) { if (root == nullptr) { return false; } + +#ifdef GUARD_INSTRUCTION_COUNT + auto n = count_instructions( + [&] { ((RootGuardManager*)root)->check_nopybind(f_locals); }); + std::cout << "#instructions in guard eval = " << n << std::endl << std::flush; +#endif + return ((RootGuardManager*)root)->check_nopybind(f_locals); } @@ -5870,6 +6002,17 @@ PyObject* torch_c_dynamo_guards_init() { // return by reference because GuardManager has the ownership of accessors .def("get_source", &GuardManager::get_source) .def("fail_count", &GuardManager::fail_count) + .def( + "is_guarded_value_immutable", + &GuardManager::is_guarded_value_immutable) + .def( + "is_guarded_value_nn_module", + &GuardManager::is_guarded_value_nn_module) + .def("is_guarded_value_dict", &GuardManager::is_guarded_value_dict) + .def( + "is_guarded_value_empty_dict", + &GuardManager::is_guarded_value_empty_dict) + .def("type_of_guarded_value", &GuardManager::type_of_guarded_value) .def( "get_accessors", &GuardManager::get_accessors, diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index dacdc9eac388..b835b1a00821 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,29 @@ namespace fs = std::filesystem; #endif namespace { + +const std::string k_separator = "/"; + +std::string normalize_path_separator(const std::string& orig_path) { + /* + On Windows and Linux have different separator: + On Windows use "\", and the path like: C:\Users\Test\file.txt + On Linux use "/", and the path like: /home/user/file.txt + + In order to simplify the path operation, we can use this function to + normalize path separator. It will convert Windows separator to Linux + separator, and reuse the common code to handle both Windows and Linux + path. + On Windows, when we input: "C:\Users\Test\file.txt", the output should be: + "C:/Users/Test/file.txt". And then, we can process the output like on Linux. + */ + std::string normalized_path = orig_path; +#ifdef _WIN32 + std::replace(normalized_path.begin(), normalized_path.end(), '\\', '/'); +#endif + return normalized_path; +} + bool file_exists(const std::string& path) { #ifdef _WIN32 return fs::exists(path); @@ -47,7 +71,16 @@ bool file_exists(const std::string& path) { std::string create_temp_dir() { #ifdef _WIN32 - throw std::runtime_error("Not implemented"); + try { + fs::path temp_dir = fs::temp_directory_path(); + return temp_dir.string(); + } catch (const fs::filesystem_error& e) { + throw std::runtime_error( + "Failed to get temporary directory: " + std::string(e.what())); + } catch (...) { + throw std::runtime_error( + "Unknown error occurred while getting temporary directory"); + } #else std::string temp_dir = "/tmp/XXXXXX"; if (mkdtemp(temp_dir.data()) == nullptr) { @@ -59,11 +92,29 @@ std::string create_temp_dir() { #endif } +const char* object_file_ext() { #ifdef _WIN32 -const std::string k_separator = "\\"; + return ".obj"; #else -const std::string k_separator = "/"; + return ".o"; +#endif +} + +const char* extension_file_ext() { +#ifdef _WIN32 + return ".pyd"; +#else + return ".so"; +#endif +} + +bool _is_windows_os() { +#ifdef _WIN32 + return true; +#else + return false; #endif +} } // namespace namespace torch::inductor { @@ -83,11 +134,12 @@ const nlohmann::json& load_json_file(const std::string& json_path) { } std::tuple get_cpp_compile_command( - const std::string& filename, + const std::string& arg_filename, const std::vector& sources, const nlohmann::json& compile_options, const std::string& output_dir = "") { // Construct the cpp command + auto filename = normalize_path_separator(arg_filename); std::string compiler = compile_options["compiler"].get(); bool compile_only = compile_options["compile_only"].get(); @@ -97,7 +149,8 @@ std::tuple get_cpp_compile_command( source_args += source + " "; } - std::string file_ext = compile_only ? ".o" : ".so"; + std::string file_ext = + compile_only ? object_file_ext() : extension_file_ext(); std::string target_file = output_dir + filename + file_ext; std::string target_dir = output_dir; if (target_dir.empty()) { @@ -107,62 +160,88 @@ std::tuple get_cpp_compile_command( std::string cflags_args; for (auto& arg : compile_options["cflags"]) { - cflags_args += "-" + arg.get() + " "; + cflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; } std::string definitions_args; for (auto& arg : compile_options["definitions"]) { - definitions_args += "-D " + arg.get() + " "; + definitions_args += + _is_windows_os() ? "/D" : "-D " + arg.get() + " "; } std::string include_dirs_args; for (auto& arg : compile_options["include_dirs"]) { - include_dirs_args += "-I" + arg.get() + " "; + include_dirs_args += + _is_windows_os() ? "/I" : "-I" + arg.get() + " "; } std::string ldflags_args; for (auto& arg : compile_options["ldflags"]) { - ldflags_args += "-" + arg.get() + " "; + ldflags_args += _is_windows_os() ? "/" : "-" + arg.get() + " "; } std::string libraries_dirs_args; for (auto& arg : compile_options["libraries_dirs"]) { - libraries_dirs_args += "-L" + arg.get() + " "; + if (_is_windows_os()) { + libraries_dirs_args += + fmt::format("/LIBPATH:\"{}\"", arg.get()) + " "; + } else { + libraries_dirs_args += "-L" + arg.get() + " "; + } } std::string libraries_args; for (auto& arg : compile_options["libraries"]) { - libraries_args += "-l" + arg.get() + " "; + if (_is_windows_os()) { + libraries_args += fmt::format("{}.lib", arg.get()) + " "; + } else { + libraries_args += "-l" + arg.get() + " "; + } } std::string passthrough_parameters_args; + std::regex script_regex(R"(--script=[^,]*script\.ld)"); + std::string replacement = + "--script=" + target_dir + k_separator + "script.ld"; for (auto& arg : compile_options["passthrough_args"]) { - std::string arg_str = arg.get(); - std::string target = "script.ld"; - std::string replacement = target_dir; - replacement.append(k_separator).append(target); - size_t pos = arg_str.find(target); - if (pos != std::string::npos) { - arg_str.replace(pos, target.length(), replacement); - } + std::string arg_str = + std::regex_replace(arg.get(), script_regex, replacement); passthrough_parameters_args += arg_str + " "; } - std::string compile_only_arg = compile_only ? "-c" : ""; - - std::string cmd = fmt::format( - "{} {} {} {} {} {} {} {} {} {} -o {}", - compiler, - source_args, - definitions_args, - cflags_args, - include_dirs_args, - passthrough_parameters_args, - ldflags_args, - libraries_args, - libraries_dirs_args, - compile_only_arg, - target_file); + std::string compile_only_arg = + compile_only ? (_is_windows_os() ? "/c" : "-c") : ""; + + std::string cmd; + if (_is_windows_os()) { + cmd = normalize_path_separator(fmt::format( + "{} {} {} {} {} {} /LD /Fe{} {} /link {} {} {}", + compiler, + include_dirs_args, + definitions_args, + cflags_args, + source_args, + passthrough_parameters_args, + target_file, + compile_only_arg, + libraries_dirs_args, + libraries_args, + ldflags_args)); + } else { + cmd = normalize_path_separator(fmt::format( + "{} {} {} {} {} {} {} {} {} {} -o {}", + compiler, + source_args, + definitions_args, + cflags_args, + include_dirs_args, + passthrough_parameters_args, + ldflags_args, + libraries_args, + libraries_dirs_args, + compile_only_arg, + target_file)); + } return std::make_tuple(cmd, target_file); } @@ -332,8 +411,6 @@ std::unordered_set find_model_names( // Escape the separator if it's backslash (needed for regex) std::string sep = k_separator; - if (sep == "\\") - sep = "\\\\"; std::string pattern = "data" + sep + "aotinductor" + sep + "([^" + sep + "]+)" + sep; @@ -364,6 +441,69 @@ void AOTIModelPackageLoader::load_metadata(const std::string& cpp_filename) { } } +class RAIIMinizArchive { + public: + RAIIMinizArchive(const std::string& zip_path) { + mz_zip_zero_struct(&_zip_archive); + if (!mz_zip_reader_init_file(&_zip_archive, zip_path.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to initialize zip archive: {}", + mz_zip_get_error_string(mz_zip_get_last_error(&_zip_archive)))); + } + } + RAIIMinizArchive(const RAIIMinizArchive&) = delete; + RAIIMinizArchive& operator=(const RAIIMinizArchive&) = delete; + RAIIMinizArchive(RAIIMinizArchive&&) noexcept = delete; + RAIIMinizArchive& operator=(RAIIMinizArchive&&) noexcept = delete; + ~RAIIMinizArchive() { + // Unconditionally close the file. We can't handle any errors here without + // terminating the program. + mz_zip_reader_end(&_zip_archive); + } + + std::vector get_filenames() { + const unsigned num_zip_files{mz_zip_reader_get_num_files(&_zip_archive)}; + std::vector zip_filenames{}; + zip_filenames.reserve(num_zip_files); + + for (unsigned i{0}; i < num_zip_files; ++i) { + // filename_buf_size == 0 returns the filename length, including null + // terminator + const auto zip_filename_len{ + mz_zip_reader_get_filename(&_zip_archive, i, nullptr, 0)}; + if (!zip_filename_len) { + throw std::runtime_error( + fmt::format("Failed to read zip filename length at index {}", i)); + } + // std::string implicitly appends a character for the null terminator + std::string zip_filename(zip_filename_len - 1, '\0'); + if (!mz_zip_reader_get_filename( + &_zip_archive, i, zip_filename.data(), zip_filename_len)) { + throw std::runtime_error( + fmt::format("Failed to read zip filename at index {}", i)); + } + zip_filenames.emplace_back(zip_filename); + } + + return zip_filenames; + } + + void extract_file( + const std::string& zip_filename, + const std::string& dest_filename) { + if (!mz_zip_reader_extract_file_to_file( + &_zip_archive, zip_filename.c_str(), dest_filename.c_str(), 0)) { + throw std::runtime_error(fmt::format( + "Failed to extract zip file {} to destination file {}", + zip_filename, + dest_filename)); + } + } + + private: + mz_zip_archive _zip_archive{}; +}; + AOTIModelPackageLoader::AOTIModelPackageLoader( const std::string& model_package_path, const std::string& model_name, @@ -383,32 +523,8 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extract all files within the zipfile to a temporary directory - mz_zip_archive zip_archive; - memset(&zip_archive, 0, sizeof(zip_archive)); - - if (!mz_zip_reader_init_file(&zip_archive, model_package_path.c_str(), 0)) { - throw std::runtime_error( - std::string("Failed to initialize zip archive: ") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - - std::vector found_filenames; - for (uint32_t i = 0; i < zip_archive.m_total_files; i++) { - uint32_t filename_len = - mz_zip_reader_get_filename(&zip_archive, i, nullptr, 0); - if (filename_len == 0) { - throw std::runtime_error("Failed to read filename"); - } - // filename_len returned by mz_zip_reader_get_filename includes the null - // terminator, so we need to subtract 1 here - std::string filename_str(filename_len - 1, '\0'); - if (!mz_zip_reader_get_filename( - &zip_archive, i, filename_str.data(), filename_len)) { - throw std::runtime_error("Failed to read filename"); - } - found_filenames.push_back(filename_str); - } - + RAIIMinizArchive zip_archive{model_package_path}; + auto found_filenames{zip_archive.get_filenames()}; if (found_filenames.empty()) { throw std::runtime_error("No files found in zip archive."); } @@ -430,32 +546,36 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( << found_filenames[1]; } - temp_dir_ = create_temp_dir(); + temp_dir_ = normalize_path_separator(create_temp_dir()); std::string so_filename; std::string cpp_filename; std::vector obj_filenames; - std::string model_directory = file_prefix + "data" + k_separator + - "aotinductor" + k_separator + model_name; - std::string const_directory = - file_prefix + "data" + k_separator + "constants"; - - for (const std::string& filename_str : found_filenames) { + std::string model_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "aotinductor" + k_separator + + model_name); + std::string const_directory = normalize_path_separator( + file_prefix + "data" + k_separator + "constants"); + + // zip_filename_str can't be normalize_path_separator, because it should be + // as index for mz_zip_reader_extract_file_to_file. + for (auto const& zip_filename_str : found_filenames) { + auto cur_filename = normalize_path_separator(zip_filename_str); // Only compile files in the specified model directory - if (c10::starts_with(filename_str, model_directory) || - c10::starts_with(filename_str, const_directory)) { + if (c10::starts_with(cur_filename, model_directory) || + c10::starts_with(cur_filename, const_directory)) { std::string output_path_str = temp_dir_; - if (c10::starts_with(filename_str, model_directory)) { + if (c10::starts_with(cur_filename, model_directory)) { output_path_str += k_separator; - output_path_str += filename_str; - } else { // startsWith(filename_str, const_directory) + output_path_str += cur_filename; + } else { // startsWith(zip_filename_str, const_directory) // Extract constants to the same directory as the rest of the files // to be consistent with internal implementation - size_t lastSlash = filename_str.find_last_of(k_separator); - std::string filename = filename_str; + size_t lastSlash = cur_filename.find_last_of(k_separator); + std::string filename = cur_filename; if (lastSlash != std::string::npos) { - filename = filename_str.substr(lastSlash + 1); + filename = cur_filename.substr(lastSlash + 1); } output_path_str.append(k_separator) .append(model_directory) @@ -463,16 +583,17 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( .append(filename); } - LOG(INFO) << "Extract file: " << filename_str << " to " - << output_path_str; + std::string output_file_path = normalize_path_separator(output_path_str); + LOG(INFO) << "Extract file: " << zip_filename_str << " to " + << output_file_path; // Create the parent directory if it doesn't exist - size_t parent_path_idx = output_path_str.find_last_of(k_separator); + size_t parent_path_idx = output_file_path.find_last_of(k_separator); if (parent_path_idx == std::string::npos) { throw std::runtime_error( - "Failed to find parent path in " + output_path_str); + "Failed to find parent path in " + output_file_path); } - std::string parent_path = output_path_str.substr(0, parent_path_idx); + std::string parent_path = output_file_path.substr(0, parent_path_idx); if (!recursive_mkdir(parent_path)) { throw std::runtime_error(fmt::format( "Failed to create directory {}: {}", @@ -481,32 +602,23 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( } // Extracts file to the temp directory - mz_zip_reader_extract_file_to_file( - &zip_archive, filename_str.c_str(), output_path_str.c_str(), 0); + zip_archive.extract_file(zip_filename_str, output_path_str); // Save the file for bookkeeping - size_t extension_idx = output_path_str.find_last_of('.'); + size_t extension_idx = output_file_path.find_last_of('.'); if (extension_idx != std::string::npos) { - std::string filename_extension = output_path_str.substr(extension_idx); + std::string filename_extension = output_file_path.substr(extension_idx); if (filename_extension == ".cpp") { - cpp_filename = output_path_str; - } else if (filename_extension == ".o") { - obj_filenames.push_back(output_path_str); - } else if (filename_extension == ".so") { - so_filename = output_path_str; + cpp_filename = output_file_path; + } else if (filename_extension == object_file_ext()) { + obj_filenames.push_back(output_file_path); + } else if (filename_extension == extension_file_ext()) { + so_filename = output_file_path; } } } } - // Close the zip archive as we have extracted all files to the temp - // directory - if (!mz_zip_reader_end(&zip_archive)) { - throw std::runtime_error( - std::string("Failed to close zip archive: {}") + - mz_zip_get_error_string(mz_zip_get_last_error(&zip_archive))); - } - if (cpp_filename.empty() && so_filename.empty()) { std::string found_filenames_str; for (const std::string& filename : found_filenames) { diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h index f2b29049bb81..fab9a87a725e 100644 --- a/torch/csrc/inductor/aoti_runtime/interface.h +++ b/torch/csrc/inductor/aoti_runtime/interface.h @@ -6,6 +6,17 @@ // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +#ifdef _WIN32 +/* +On Windows, we need to explicit declaration for export APIs. And because the +package loader call these API via GetProcAddress(ldsym on Linux), we can ignore +the import case. +*/ +#define AOTI_API __declspec(dllexport) +#else +#define AOTI_API __attribute__((__visibility__("default"))) +#endif + extern "C" { struct AOTInductorModelOpaque; using AOTInductorModelHandle = AOTInductorModelOpaque*; @@ -21,7 +32,7 @@ using AOTInductorConstantMapHandle = AOTInductorConstantMap*; // TODO: Deprecate this API. This was kept for BC compatibility. // Please use AOTInductorModelContainerCreateWithDevice instead. -AOTIRuntimeError AOTInductorModelContainerCreate( +AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate( AOTInductorModelContainerHandle* container_handle, size_t num_models, bool is_cpu, @@ -34,18 +45,18 @@ AOTIRuntimeError AOTInductorModelContainerCreate( // "cpu", "cuda", "cuda:0", etc. If the device index is not specified for CUDA // device, runtime will use the device index returned by // "cudaGetDevice(&device_idx)" -AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( +AOTI_API AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( AOTInductorModelContainerHandle* container_handle, size_t num_models, const char* device_str, const char* cubin_dir); // Deletes the AOTInductor model container. -AOTIRuntimeError AOTInductorModelContainerDelete( +AOTI_API AOTIRuntimeError AOTInductorModelContainerDelete( AOTInductorModelContainerHandle container_handle); // Runs the inference. -AOTIRuntimeError AOTInductorModelContainerRun( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRun( AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -59,7 +70,7 @@ AOTIRuntimeError AOTInductorModelContainerRun( AOTIProxyExecutorHandle proxy_executor_handle); // Single-threaded variant of previous. -AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( AOTInductorModelContainerHandle container_handle, AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles // are stolen; the array itself is borrowed @@ -73,14 +84,14 @@ AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( AOTIProxyExecutorHandle proxy_executor_handle); // Retrieves the number of constants for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumConstants( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumConstants( AOTInductorModelContainerHandle container_handle, size_t* num_constants); // Retrieves a constant's name. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantName( AOTInductorModelContainerHandle container_handle, size_t idx, const char** name); @@ -88,7 +99,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantName( // Retrieves a constant's original FQN. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( AOTInductorModelContainerHandle container_handle, size_t idx, const char** original_fqn); @@ -96,7 +107,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( // Retrieves whether a constant is from folded. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( AOTInductorModelContainerHandle container_handle, size_t idx, bool* from_folded); @@ -104,7 +115,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( // Retrieves the inductor constant type. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantType( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantType( AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* type); @@ -112,7 +123,7 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantType( // Retrieves a constant's dtype. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( AOTInductorModelContainerHandle container_handle, size_t idx, int32_t* dtype); @@ -120,20 +131,21 @@ AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( // Retrieves a constant's data size. // idx is the index of the internal's constants. // Need idx < num_constants from AOTInductorModelContainerGetNumConstants -AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( AOTInductorModelContainerHandle container_handle, size_t idx, size_t* data_size); // Extract the constants that is being used in the container. -AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( +AOTI_API AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive); // Setup the constant buffer in model container with provided ConstantMap. // The ConstantMap is user managed, and the user would retain ownership. -AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( +AOTI_API AOTIRuntimeError +AOTInductorModelContainerUpdateUserManagedConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -142,7 +154,7 @@ AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( // Setup the constant buffer in model container with provided ConstantMap // use_inactive should be set as true if the inactive buffer is to be updated. // validate_full_update checks if all constants are included in the ConstantMap -AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, @@ -150,43 +162,43 @@ AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( // Setup the inactive constant buffer in model container with provided // ConstantMap -AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle); // Free the inactive constant buffer in model container. -AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( AOTInductorModelContainerHandle container_handle); // Run constant folding on constant buffer. -AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( +AOTI_API AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( AOTInductorModelContainerHandle container_handle, bool use_inactive, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle); // Swap the constant buffer being used to the inactive one. -AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( +AOTI_API AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( AOTInductorModelContainerHandle container_handle); // Retrieves the number of inputs for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumInputs( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumInputs( AOTInductorModelContainerHandle container_handle, size_t* ret_num_inputs); // Retrieves the input name at the given index. -AOTIRuntimeError AOTInductorModelContainerGetInputName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetInputName( AOTInductorModelContainerHandle container_handle, size_t input_idx, const char** ret_input_names); // Retrieves the number of outputs for the model. -AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( AOTInductorModelContainerHandle container_handle, size_t* ret_num_outputs); // Retrieves the output name at the given index. -AOTIRuntimeError AOTInductorModelContainerGetOutputName( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetOutputName( AOTInductorModelContainerHandle container_handle, size_t output_idx, const char** ret_output_names); @@ -198,31 +210,32 @@ AOTIRuntimeError AOTInductorModelContainerGetOutputName( // // constant_map_handle is an opaque type to satisfy the C ABI. It should be a // std::unordered_map*. -AOTIRuntimeError AOTInductorModelCreate( +AOTI_API AOTIRuntimeError AOTInductorModelCreate( AOTInductorModelHandle* model_handle, AOTInductorConstantMapHandle constant_map_handle); // Run an AOTInductorModel (see AOTInductorModelCreate for when one should use // this function versus AOTInductorModelContainerRun). -AOTIRuntimeError AOTInductorModelRun( +AOTI_API AOTIRuntimeError AOTInductorModelRun( AOTInductorModelHandle model_handle, AtenTensorHandle* input_handles, AtenTensorHandle* output_handles); // Replace AOTInductorModel's constant map. Note it doesn't handle concurrency // so be sure to handle ordering if AOTInductorModelRun is ran concurrently. -AOTIRuntimeError AOTInductorModelUpdateConstantsMap( +AOTI_API AOTIRuntimeError AOTInductorModelUpdateConstantsMap( AOTInductorModelHandle model_handle, AOTInductorConstantMapHandle constant_map_handle); // Delete an AOTInductorModel created by AOTInductorModelCreate. -AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); +AOTI_API AOTIRuntimeError +AOTInductorModelDelete(AOTInductorModelHandle model_handle); -AOTIRuntimeError AOTInductorModelGetNumOutputs( +AOTI_API AOTIRuntimeError AOTInductorModelGetNumOutputs( AOTInductorModelHandle model_handle, size_t* ret_num_outputs); -AOTIRuntimeError AOTInductorModelContainerGetCallSpec( +AOTI_API AOTIRuntimeError AOTInductorModelContainerGetCallSpec( AOTInductorModelContainerHandle container_handle, const char** in_spec, const char** out_spec); diff --git a/torch/csrc/inductor/aoti_runtime/model_base.h b/torch/csrc/inductor/aoti_runtime/model_base.h index 9eac761b7ef8..6e80c90499a0 100644 --- a/torch/csrc/inductor/aoti_runtime/model_base.h +++ b/torch/csrc/inductor/aoti_runtime/model_base.h @@ -1,9 +1,15 @@ #pragma once +#ifdef _WIN32 +#include +#include // std::function +#else #include -#include #include #include +#endif + +#include #include #include #include diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 416c186a3ae0..0bd12e841e39 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -467,14 +467,34 @@ class AOTInductorModelContainer { constants_blob_ptr + constants_internal_offset_[idx]; void* user_constant_ptr; int64_t constant_size; + int64_t* stride; + int64_t offset; aoti_torch_get_data_ptr(tensor, &user_constant_ptr); aoti_torch_get_storage_size(tensor, &constant_size); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); + auto dtype = models_[0]->constant_dtype(idx); + #ifdef USE_XPU sycl::queue* queue_ptr = nullptr; aoti_torch_get_current_sycl_queue((void**)&queue_ptr); queue_ptr ->memcpy(internal_constants_ptr, user_constant_ptr, constant_size) .wait(); +#elif USE_MPS + internal_constants_ptr = constants_blob_ptr; + aoti_torch_mps_copy_buffer( + user_constant_ptr, + constants_blob_ptr, + constant_size, + offset, + constants_internal_offset_[idx]); + // For mps tensors, all constants are stored in one buffer, with the + // offset being where the constant starts. So we want to change the + // constant tensor's offset to point to constants_internal_offset_[idx] + offset = constants_internal_offset_[idx] / + aoti_torch_dtype_element_size(dtype); #elif USE_CUDA AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy( internal_constants_ptr, @@ -488,20 +508,15 @@ class AOTInductorModelContainer { // We extract stride and offset from provided Tensor since we do not // guarantee that the tensor is contiguous. AtenTensorHandle tensor_handle; - int64_t* stride; - int64_t offset; int device_type = models_[0]->get_device_type(); int device_idx = models_[0]->get_device_idx(); - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_get_storage_offset(tensor, &offset)); AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob( internal_constants_ptr, models_[0]->constant_ndim(idx), models_[0]->constant_shape(idx), stride, offset, - models_[0]->constant_dtype(idx), + dtype, device_type, device_idx, &tensor_handle)); diff --git a/torch/csrc/inductor/aoti_torch/c/macros.h b/torch/csrc/inductor/aoti_torch/c/macros.h new file mode 100644 index 000000000000..6f1346cdcf86 --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/macros.h @@ -0,0 +1,63 @@ +#ifndef AOTI_TORCH_MACRO_H +#define AOTI_TORCH_MACRO_H + +#include +#include +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +#ifdef __cplusplus +extern "C" { +#endif +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque*; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque*; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // AOTI_TORCH_MACRO_H diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 6a23c9d465c7..9d512ce1f481 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -1,8 +1,8 @@ #ifndef AOTI_TORCH_SHIM #define AOTI_TORCH_SHIM -#include -#include +#include +#include // This header defines a stable C API for certain ATen functionality in // libtorch. The AOTInductor compiled model.so will only refer to this header @@ -36,29 +36,6 @@ // maintain the old and new versions of the APIs until all old model.so // go out of use. -#ifdef __GNUC__ -#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) -#else // !__GNUC__ -#ifdef _WIN32 -// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead -// to symbol clashes at link time if libtorch is included in a DLL and binary -// that depends on the DLL. As a short term fix, we don't export the symbols. -// In the long term, this will need to be addressed when Windows is supported. -#ifdef OVRSOURCE -// Do not export AOTI on Windows for internal builds -#define AOTI_TORCH_EXPORT -#else /* OVRSOURCE */ -#ifdef EXPORT_AOTI_FUNCTIONS -#define AOTI_TORCH_EXPORT __declspec(dllexport) -#else -#define AOTI_TORCH_EXPORT __declspec(dllimport) -#endif -#endif /* OVRSOURCE */ -#else // !_WIN32 -#define AOTI_TORCH_EXPORT -#endif // _WIN32 -#endif // __GNUC__ - // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check #include @@ -69,33 +46,6 @@ extern "C" { #endif -// AtenTensorHandle represents an abstract notion of Tensor that can be passed -// between model.so and libtorch.so. The contents of the structure itself -// are private; model.so is not allowed to access any fields directly, it must -// go through functions defined in this ABI. Under the hood, this is -// represented as at::Tensor*, but we reserve the right to change this (and in -// fact, we probably should change it to at::TensorImpl* at least). -// -// An AtenTensorHandle can be owning (please check the API reference for exact -// ownership/borrow semantics). If you have an owning AtenTensorHandle -// in model.so, you are obligated to aoti_torch_delete_tensor_object when you -// are done. You can use the helper C++ class RAIIAtenTensorHandle -// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style -// (note that RAIIAtenTensorHandle is private to model.so, and never crosses -// the ABI boundary.) -struct AtenTensorOpaque; -using AtenTensorHandle = AtenTensorOpaque*; - -struct AtenGeneratorOpaque; -using AtenGeneratorHandle = AtenGeneratorOpaque*; - -struct AOTIProxyExecutorOpaque; -using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque*; - -using AOTITorchError = int32_t; -#define AOTI_TORCH_SUCCESS 0 -#define AOTI_TORCH_FAILURE 1 - // Getter functions for retrieving various constants from the runtime, that // can subsequently be passed to other aoti_* functions. By hiding these // behind functions, the precise value of device/dtype is NOT part of the @@ -349,127 +299,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( const uint8_t* opaque_metadata, int64_t opaque_metadata_size); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( - AtenTensorHandle weight, - AtenTensorHandle indices, - AtenTensorHandle offsets, - int32_t scale_grad_by_freq, - int32_t mode, - int32_t sparse, - AtenTensorHandle per_sample_weights, // optional argument - int32_t include_last_offset, - int32_t padding_idx, - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( - AtenTensorHandle self, - const int64_t* dim_ptr, - int64_t dim_size, - int64_t normalization, - int32_t forward, - AtenTensorHandle* ret // returns new reference -); - -// This version is deprecated. We will remove it later -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - double dropout_p, - bool is_causal, - bool return_debug_mask, - double scale, - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3, // returns new reference - int64_t* ret4, - int64_t* ret5, - AtenTensorHandle* ret6, // returns new reference - AtenTensorHandle* ret7, // returns new reference - AtenTensorHandle* ret8 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch__scaled_dot_product_flash_attention_v2( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - double dropout_p, - int is_causal, - int return_debug_mask, - double* scale, // optional argument - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3, // returns new reference - int64_t* ret4, - int64_t* ret5, - AtenTensorHandle* ret6, // returns new reference - AtenTensorHandle* ret7, // returns new reference - AtenTensorHandle* ret8 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch__scaled_dot_product_efficient_attention( - AtenTensorHandle query, - AtenTensorHandle key, - AtenTensorHandle value, - AtenTensorHandle attn_bias, // optional argument - int compute_log_sumexp, - double dropout_p, - int is_causal, - double* scale, // optional argument - AtenTensorHandle* ret0, // returns new reference - AtenTensorHandle* ret1, // returns new reference - AtenTensorHandle* ret2, // returns new reference - AtenTensorHandle* ret3 // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( - AtenTensorHandle self, - AtenTensorHandle mat2, - AtenTensorHandle bias, - int32_t* out_dtype, - AtenTensorHandle scale_a, - AtenTensorHandle scale_b, - AtenTensorHandle scale_result, - int8_t use_fast_accum, - AtenTensorHandle* ret0, - AtenTensorHandle* ret1); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( - AtenTensorHandle self, - AtenTensorHandle mat2, - AtenTensorHandle scale_a, - AtenTensorHandle scale_b, - AtenTensorHandle bias, - AtenTensorHandle scale_result, - int32_t* out_dtype, - int8_t use_fast_accum, - AtenTensorHandle* ret0); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( - AtenTensorHandle input, - AtenTensorHandle weight, - AtenTensorHandle bias, // optional argument - const int64_t* stride_ptr, - int64_t stride_size, - const int64_t* padding_ptr, - int64_t padding_size, - const int64_t* dilation_ptr, - int64_t dilation_size, - int transposed, - const int64_t* output_padding_ptr, - int64_t output_padding_size, - int64_t groups, - AtenTensorHandle* ret // returns new reference -); - // This function will create a new uninitialized tensor object // and its pointer is returned through *ret. AOTI_TORCH_EXPORT AOTITorchError @@ -502,29 +331,11 @@ aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle* ret); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat1, - AtenTensorHandle mat2, - float beta, - float alpha); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat2); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_( AtenTensorHandle self, AtenTensorHandle src, int32_t non_blocking); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( - AtenTensorHandle out, - AtenTensorHandle self, - AtenTensorHandle mat2); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( AtenTensorHandle out, AtenTensorHandle a, @@ -554,7 +365,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, - AtenTensorHandle bias, + AtenTensorHandle bias, // optional argument int64_t out_channel, AtenTensorHandle* out); @@ -571,16 +382,8 @@ aoti_torch_cpu__wrapped_quantized_linear_prepacked( int64_t out_channel, AtenTensorHandle* out); -AOTI_TORCH_EXPORT AOTITorchError -aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( - AtenTensorHandle repeats, - int64_t* output_size, - AtenTensorHandle* out); - AOTI_TORCH_EXPORT AOTITorchError aoti_torch_check_inf_and_nan(const char* tensor_name, AtenTensorHandle tensor); @@ -608,17 +411,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out( const AtenTensorHandle values, bool accumulate); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( - AtenTensorHandle self, - AtenTensorHandle* ret // returns new reference -); - -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( - AtenTensorHandle self, - int32_t dtype, - AtenTensorHandle* ret // returns new reference -); - AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( AtenTensorHandle self, const char* msg); diff --git a/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h b/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h new file mode 100644 index 000000000000..964db6b0076c --- /dev/null +++ b/torch/csrc/inductor/aoti_torch/c/shim_deprecated.h @@ -0,0 +1,199 @@ +#ifndef AOTI_TORCH_SHIM_DEPRECATED +#define AOTI_TORCH_SHIM_DEPRECATED + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +[[deprecated( + "aoti_torch__embedding_bag is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, + AtenTensorHandle indices, + AtenTensorHandle offsets, + int32_t scale_grad_by_freq, + int32_t mode, + int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, + int32_t padding_idx, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +[[deprecated( + "aoti_torch__fft_c2c is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, + const int64_t* dim_ptr, + int64_t dim_size, + int64_t normalization, + int32_t forward, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_mm is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle bias, + int32_t* out_dtype, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle scale_result, + int8_t use_fast_accum, + AtenTensorHandle* ret0, + AtenTensorHandle* ret1); + +[[deprecated( + "aoti_torch__scaled_mm_v2 is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, + AtenTensorHandle mat2, + AtenTensorHandle scale_a, + AtenTensorHandle scale_b, + AtenTensorHandle bias, + AtenTensorHandle scale_result, + int32_t* out_dtype, + int8_t use_fast_accum, + AtenTensorHandle* ret0); + +[[deprecated( + "aoti_torch_addmm_out is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, + float alpha); + +[[deprecated( + "aoti_torch_bmm is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +[[deprecated( + "aoti_torch_convolution is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t* stride_ptr, + int64_t stride_size, + const int64_t* padding_ptr, + int64_t padding_size, + const int64_t* dilation_ptr, + int64_t dilation_size, + int transposed, + const int64_t* output_padding_ptr, + int64_t output_padding_size, + int64_t groups, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch_mm_out is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out( + AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +[[deprecated( + "aoti_torch_nonzero is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out); + +[[deprecated( + "aoti_torch_repeat_interleave_Tensor is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( + AtenTensorHandle repeats, + int64_t* output_size, + AtenTensorHandle* out); + +[[deprecated( + "aoti_torch_view_as_real is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( + AtenTensorHandle self, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch_view_dtype is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( + AtenTensorHandle self, + int32_t dtype, + AtenTensorHandle* ret // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_dot_product_flash_attention is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + bool is_causal, + bool return_debug_mask, + double scale, + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + double dropout_p, + int is_causal, + int return_debug_mask, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3, // returns new reference + int64_t* ret4, + int64_t* ret5, + AtenTensorHandle* ret6, // returns new reference + AtenTensorHandle* ret7, // returns new reference + AtenTensorHandle* ret8 // returns new reference +); + +[[deprecated( + "aoti_torch__scaled_dot_product_efficient_attention is deprecated and will be removed in future versions.")]] +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, + AtenTensorHandle key, + AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, + double dropout_p, + int is_causal, + double* scale, // optional argument + AtenTensorHandle* ret0, // returns new reference + AtenTensorHandle* ret1, // returns new reference + AtenTensorHandle* ret2, // returns new reference + AtenTensorHandle* ret3 // returns new reference +); + +#ifdef __cplusplus +} // extern "C" + +#endif +#endif // AOTI_TORCH_SHIM_DEPRECATED diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mps.h b/torch/csrc/inductor/aoti_torch/c/shim_mps.h index bd86885de13c..08f1569927f0 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_mps.h @@ -32,6 +32,13 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_memcpy( size_t data_size, uint8_t* constants_start); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_copy_buffer( + void* src_buffer, + void* dst_buffer, + size_t data_size, + size_t src_offset, + size_t dst_offset); + #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2aa09cb802ec..aced2b2f539d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -29,6 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_c2c(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index e0607f984b3d..92d30ded855f 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -32,6 +32,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenT AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__int_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_forward(AtenTensorHandle self, double p, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index a5d654c51884..c76ee685c25d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -18,7 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__efficientzerotensor(const int64 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, int64_t normalized_shape_ndim, AtenTensorHandle weight, double eps, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 243bfb5fc87a..6fc51bd0c8f8 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -13,6 +13,7 @@ extern "C" { AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index dc6e52b0c4db..a33198fd1ba0 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -981,16 +981,17 @@ AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( AOTITorchError aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight( AtenTensorHandle input, AtenTensorHandle weight, - AtenTensorHandle bias, + AtenTensorHandle bias, // optional argument int64_t out_channel, AtenTensorHandle* out) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ at::Tensor* input_tensor = tensor_handle_to_tensor_pointer(input); at::Tensor* weight_tensor = tensor_handle_to_tensor_pointer(weight); - at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias); + auto optional_bias_tensor = + pointer_to_optional(tensor_handle_to_tensor_pointer(bias)); *out = new_tensor_handle(at::fbgemm_linear_fp16_weight_fp32_activation( - *input_tensor, *weight_tensor, *bias_tensor)); + *input_tensor, *weight_tensor, optional_bias_tensor)); }); } diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.mm b/torch/csrc/inductor/aoti_torch/shim_mps.mm index 9f70331ffc0b..1bf88839ecfe 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mps.mm +++ b/torch/csrc/inductor/aoti_torch/shim_mps.mm @@ -3,6 +3,8 @@ #include #include #include +#include +#include using namespace torch::aot_inductor; @@ -40,3 +42,16 @@ AOTITorchError aoti_torch_mps_free( memcpy(buffer_pointer + constant_offset, constants_start + bytes_read, data_size); }); } + +AOTITorchError +aoti_torch_mps_copy_buffer(void* src_buffer, void* dst_buffer, size_t data_size, size_t src_offset, size_t dst_offset) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + + auto src_mtl_buffer = (id)src_buffer; + auto dst_mtl_buffer = (id)dst_buffer; + + auto* stream = at::mps::getCurrentMPSStream(); + uint64_t profile_id = at::mps::getMPSProfiler().beginProfileCopy(src_mtl_buffer, dst_mtl_buffer, at::OptionalTensorRef(), at::OptionalTensorRef(), data_size, true); + stream->copy_and_sync(src_mtl_buffer, dst_mtl_buffer, data_size, src_offset, dst_offset, true, profile_id); + }); +} diff --git a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md index 731f70a5b826..c7bcea497648 100644 --- a/torch/csrc/jit/tensorexpr/ConditionalsInTE.md +++ b/torch/csrc/jit/tensorexpr/ConditionalsInTE.md @@ -14,7 +14,7 @@ So far the recommendation was to standardize on fused conditionals. ## Expression Conditionals vs Statement Conditionals -Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expression important operators like ReLU: +Tensor IR contains both expression conditionals (`CompareSelect` and `IfThenElse`), as well as statement conditionals (`Cond`). Expression conditionals are defined by being functional in nature: there is no side effect from duplicating the conditional, evaluating it twice, etc. They are an important ingredient in expressing important operators like ReLU: ``` store (((load A) >= 0.0) ? (load A) : 0.0), B diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index d6c5590a7100..918d82579444 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -83,7 +83,7 @@ using namespace torch::jit::tensorexpr; C10_DEFINE_bool( torch_jit_llvm_use_fast_intrinsics, false, - "Use fast (but slightly less accurate) implementations of tanh and sigmoid"); + "Use fast (but slightly less accurate) implementations of tanh and sigmoid") namespace torch::jit::tensorexpr { @@ -246,7 +246,7 @@ class LLVMCodeGenImpl : public IRVisitor { std::string kernel_func_name_; #define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_; - AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE); + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE) #undef LLVM_TYPE_DECLARE #if LLVM_VERSION_MAJOR >= 15 @@ -1101,7 +1101,7 @@ std::enable_if_t, llvm::Value*> getFromType( void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \ value_ = getFromType(Name##Ty_, v->value()); \ } -AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE); +AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE) #undef IMM_VISIT_DECLARE void LLVMCodeGenImpl::visit(const HalfImmPtr& v) { diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index c9a930576cdc..80d919a5674e 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -11,6 +11,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include @@ -35,6 +36,7 @@ C10_DIAGNOSTIC_POP() #endif #include #include +C10_DIAGNOSTIC_POP() #include diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h index beadbdd5e537..19a21329b64a 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.h +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -9,9 +9,11 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wsuggest-override") #include C10_DIAGNOSTIC_POP() +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wextra-semi") #include #include #include +C10_DIAGNOSTIC_POP() #include #include diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 646801fa9a19..7f0888666d3a 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1843,11 +1843,11 @@ bool LoopNest::hasLoopCarriedDependence(const ForPtr& loop) { auto bLoads = NodeFinder::find(*it2); // ReadAfterWrite for (auto& aStore : aStores) { - for (auto& bLoad : bLoads) { // codespell:ignore + for (auto& bLoad : bLoads) { if (aStore->buf() == bLoad->buf()) { if (!areIndicesLoopIndependent( aStore->indices(), bLoad->indices(), outer_loop_vars)) { - if (isOverlapping(analyzer, aStore, bLoad)) { // codespell:ignore + if (isOverlapping(analyzer, aStore, bLoad)) { return true; } } diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 92f2f39a5da2..062f87a465cc 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -356,10 +356,10 @@ void initPythonBindings(PyObject* module) { " that expose CUDA device, stream and event synchronization activities. This feature is new\n" " and currently disabled by default.\n" " adjust_profiler_step (bool) : whether to adjust the profiler step to\n" - " match the parent python event duration. This feature is new and currently disabled by default.\n", - " disable_external_correlation (bool) : whether to disable external correlation\n", - " profile_all_threads (bool) : whether to profile all threads\n", - " capture_overload_names (bool) : whether to include ATen overload names in the profile\n", + " match the parent python event duration. This feature is new and currently disabled by default.\n" + " disable_external_correlation (bool) : whether to disable external correlation\n" + " profile_all_threads (bool) : whether to profile all threads\n" + " capture_overload_names (bool) : whether to include ATen overload names in the profile\n" " custom_profiler_config (string) : Used to pass some configurations to the custom profiler backend.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h new file mode 100644 index 000000000000..d469abbd55ac --- /dev/null +++ b/torch/csrc/stable/ops.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +using torch::stable::Tensor; + +// We expect this to be the stable version of the empty_like op that takes in +// no kwargs (device, dtype, layout, memory_format). We will add kwargs +// support in the future. +inline Tensor empty_like(const Tensor& self) { + const auto num_args = 6; + std::array stack{ + from(self), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt), + from(std::nullopt)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::empty_like", "", stack.data())); + return to(stack[0]); +} + +// We expect this to be the stable version of the transpose op with identical +// semantics to the existing transpose.int op. +inline Tensor transpose(const Tensor& self, int64_t dim0, int64_t dim1) { + const auto num_args = 3; + std::array stack{from(self), from(dim0), from(dim1)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::transpose", "int", stack.data())); + return to(stack[0]); +} + +// We expect this to be the stable version of the zero_ op with identical +// semantics to the existing zero_ op (except that it will not be called as +// a tensor method but only as a function i.e. zero_(t) not t.zero_()). +inline Tensor zero_(Tensor& self) { + const auto num_args = 1; + std::array stack{from(self)}; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_call_dispatcher("aten::zero_", "", stack.data())); + return to(stack[0]); +} diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 67fd1ecf05c0..98803390e510 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<110c364974d3b0f7dcbdf6862781212bdcc7178925c43c894c336fc2b6ca6628>> +// checksum<<5c990535d373dcaa291a4f994b4d7b025e0f8e806ca5268085ef699d0e4d3000>> // clang-format off #pragma once @@ -283,6 +283,8 @@ enum class ScalarType { UINT16 = 28, FLOAT8E4M3FN = 29, FLOAT8E5M2 = 30, + FLOAT8E4M3FNUZ = 31, + FLOAT8E5M2FNUZ = 32, }; inline std::string_view printEnum(const ScalarType& e) { @@ -304,6 +306,8 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::UINT16: return "UINT16"; case ScalarType::FLOAT8E4M3FN: return "FLOAT8E4M3FN"; case ScalarType::FLOAT8E5M2: return "FLOAT8E5M2"; + case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; + case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; default: throw std::runtime_error("Unknown enum value"); } @@ -327,6 +331,8 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "UINT16") { t = ScalarType::UINT16; return; } if (s == "FLOAT8E4M3FN") { t = ScalarType::FLOAT8E4M3FN; return; } if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } + if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } + if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } throw std::runtime_error("Unknown enum value: " + std::string{s}); } diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index a1537611cc47..16292e4fd030 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -13,6 +13,7 @@ extern "C" { #define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000 #define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000 #define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000 +#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000 static inline int PyCode_GetNCellvars(PyCodeObject* code) { // gh-26364 added co_ncellvars to Python 3.11.0rc1 diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index fc9d09ce63a6..6a2d62bd424c 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -18,7 +18,7 @@ import traceback import warnings from functools import lru_cache -from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, NewType, Optional, TYPE_CHECKING, Union import torch import torch._C @@ -244,21 +244,27 @@ def _extract_arch_version(arch_string: str) -> int: def _check_capability(): - incorrect_binary_warn = """ - Found GPU%d %s which requires CUDA_VERSION >= %d to - work properly, but your PyTorch was compiled - with CUDA_VERSION %d. Please install the correct PyTorch binary - using instructions from https://pytorch.org - """ # noqa: F841 - - old_gpu_warn = """ + incompatible_gpu_warn = """ Found GPU%d %s which is of cuda capability %d.%d. - PyTorch no longer supports this GPU because it is too old. - The minimum cuda capability supported by this library is %d.%d. + Minimum and Maximum cuda capability supported by this version of PyTorch is + (%d.%d) - (%d.%d) """ + matched_cuda_warn = """ + Please install PyTorch with a following CUDA + configurations: {} following instructions at + https://pytorch.org/get-started/locally/ + """ + + # Binary CUDA_ARCHES SUPPORTED by PyTorch + CUDA_ARCHES_SUPPORTED = { + "12.6": {"min": 50, "max": 90}, + "12.8": {"min": 70, "max": 120}, + "12.9": {"min": 70, "max": 120}, + } - if torch.version.cuda is not None: # on ROCm we don't want this check - CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841 + if ( + torch.version.cuda is not None and torch.cuda.get_arch_list() + ): # on ROCm we don't want this check for d in range(device_count()): capability = get_device_capability(d) major = capability[0] @@ -267,13 +273,35 @@ def _check_capability(): current_arch = major * 10 + minor min_arch = min( (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), - default=35, + default=50, + ) + max_arch = max( + (_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), + default=50, ) - if current_arch < min_arch: + if current_arch < min_arch or current_arch > max_arch: warnings.warn( - old_gpu_warn - % (d, name, major, minor, min_arch // 10, min_arch % 10) + incompatible_gpu_warn + % ( + d, + name, + major, + minor, + min_arch // 10, + min_arch % 10, + max_arch // 10, + max_arch % 10, + ) ) + matched_arches = "" + for arch, arch_info in CUDA_ARCHES_SUPPORTED.items(): + if ( + current_arch >= arch_info["min"] + and current_arch <= arch_info["max"] + ): + matched_arches += f" {arch}" + if matched_arches != "": + warnings.warn(matched_cuda_warn.format(matched_arches)) def _check_cubins(): @@ -379,8 +407,6 @@ def _lazy_init(): ) # This function throws if there's a driver initialization error, no GPUs # are found or any other error occurs - if "CUDA_MODULE_LOADING" not in os.environ: - os.environ["CUDA_MODULE_LOADING"] = "LAZY" torch._C._cuda_init() # Some of the queued calls may reentrantly call _lazy_init(); # we need to just return without initializing in that case. @@ -1777,6 +1803,9 @@ def _compile_kernel( from . import amp, jiterator, nvtx, profiler, sparse, tunable +_POOL_HANDLE = NewType("_POOL_HANDLE", tuple[int, int]) + + __all__ = [ # Typed storage and tensors "BFloat16Storage", diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index b58a7808593d..b1d1e4f8c478 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -1,12 +1,34 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import gc import typing +from typing import Callable, Optional, overload, TYPE_CHECKING, Union +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar import torch +from torch import Tensor + + +if TYPE_CHECKING: + # importing _POOL_HANDLE at runtime toplevel causes an import cycle + from torch.cuda import _POOL_HANDLE from .._utils import _dummy_type +__all__ = [ + "is_current_stream_capturing", + "graph_pool_handle", + "CUDAGraph", + "graph", + "make_graphed_callables", +] + + +_R = TypeVar("_R") +_P = ParamSpec("_P") + + if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes torch._C.__dict__["_CUDAGraph"] = _dummy_type("_CUDAGraph") @@ -22,7 +44,7 @@ ) -def is_current_stream_capturing(): +def is_current_stream_capturing() -> bool: r"""Return True if CUDA graph capture is underway on the current CUDA stream, False otherwise. If a CUDA context does not exist on the current device, returns False without initializing the context. @@ -31,7 +53,7 @@ def is_current_stream_capturing(): # Python shim helps Sphinx process docstrings more reliably. -def graph_pool_handle(): +def graph_pool_handle() -> _POOL_HANDLE: r"""Return an opaque token representing the id of a graph memory pool. See :ref:`Graph memory management`. @@ -39,7 +61,7 @@ def graph_pool_handle(): .. warning:: This API is in beta and may change in future releases. """ - return _graph_pool_handle() + return torch.cuda._POOL_HANDLE(_graph_pool_handle()) # Python shim helps Sphinx process docstrings more reliably. @@ -70,10 +92,12 @@ class CUDAGraph(torch._C._CUDAGraph): """ - def __new__(cls, keep_graph=False): + def __new__(cls, keep_graph: bool = False) -> Self: return super().__new__(cls, keep_graph) - def capture_begin(self, pool=None, capture_error_mode="global"): + def capture_begin( + self, pool: Optional[_POOL_HANDLE] = None, capture_error_mode: str = "global" + ) -> None: r"""Begin capturing CUDA work on the current stream. Typically, you shouldn't call ``capture_begin`` yourself. @@ -92,7 +116,7 @@ def capture_begin(self, pool=None, capture_error_mode="global"): """ # noqa: B950 super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) - def capture_end(self): + def capture_end(self) -> None: r"""End CUDA graph capture on the current stream. After ``capture_end``, ``replay`` may be called on this instance. @@ -103,7 +127,7 @@ def capture_end(self): """ super().capture_end() - def instantiate(self): + def instantiate(self) -> None: r"""Instantiate the CUDA graph. Will be called by ``capture_end`` if ``keep_graph=False``, or by ``replay`` if ``keep_graph=True`` and ``instantiate`` has not already been @@ -112,15 +136,15 @@ def instantiate(self): """ super().instantiate() - def replay(self): + def replay(self) -> None: r"""Replay the CUDA work captured by this graph.""" super().replay() - def reset(self): + def reset(self) -> None: r"""Delete the graph currently held by this instance.""" super().reset() - def pool(self): + def pool(self) -> _POOL_HANDLE: r"""Return an opaque token representing the id of this graph's memory pool. This id can optionally be passed to another graph's ``capture_begin``, @@ -128,11 +152,11 @@ def pool(self): """ return super().pool() - def enable_debug_mode(self): + def enable_debug_mode(self) -> None: r"""Enable debugging mode for CUDAGraph.debug_dump.""" return super().enable_debug_mode() - def debug_dump(self, debug_path): + def debug_dump(self, debug_path: str) -> None: r""" Arguments: debug_path (required): Path to dump the graph to. @@ -142,7 +166,7 @@ def debug_dump(self, debug_path): """ return super().debug_dump(debug_path) - def raw_cuda_graph(self): + def raw_cuda_graph(self) -> int: r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ @@ -180,13 +204,13 @@ class graph: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 """ # noqa: B950 - default_capture_stream: typing.Optional["torch.cuda.Stream"] = None + default_capture_stream: Optional[torch.cuda.Stream] = None def __init__( self, - cuda_graph, - pool=None, - stream=None, + cuda_graph: CUDAGraph, + pool: Optional[_POOL_HANDLE] = None, + stream: Optional[torch.cuda.Stream] = None, capture_error_mode: str = "global", ): # Lazy-init of default_capture_stream helps avoid circular-import errors. @@ -195,7 +219,9 @@ def __init__( if self.__class__.default_capture_stream is None: self.__class__.default_capture_stream = torch.cuda.Stream() - self.pool = () if pool is None else (pool,) + self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = ( + () if pool is None else (pool,) + ) self.capture_stream = ( stream if stream is not None else self.__class__.default_capture_stream ) @@ -204,7 +230,7 @@ def __init__( self.cuda_graph = cuda_graph self.capture_error_mode = capture_error_mode - def __enter__(self): + def __enter__(self) -> None: # Free as much memory as we can for the graph torch.cuda.synchronize() gc.collect() @@ -215,18 +241,47 @@ def __enter__(self): self.stream_ctx.__enter__() self.cuda_graph.capture_begin( - *self.pool, capture_error_mode=self.capture_error_mode + # type: ignore[misc] + *self.pool, + capture_error_mode=self.capture_error_mode, ) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *args: object) -> None: self.cuda_graph.capture_end() - self.stream_ctx.__exit__(exc_type, exc_value, traceback) + self.stream_ctx.__exit__(*args) # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() +_ModuleOrCallable: TypeAlias = Union["torch.nn.Module", Callable[..., object]] + + +@overload def make_graphed_callables( - callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None -): + callables: _ModuleOrCallable, + sample_args: tuple[Tensor, ...], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> _ModuleOrCallable: ... + + +@overload +def make_graphed_callables( + callables: tuple[_ModuleOrCallable, ...], + sample_args: tuple[tuple[Tensor, ...], ...], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> tuple[_ModuleOrCallable, ...]: ... + + +def make_graphed_callables( + callables: Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]], + sample_args: Union[tuple[Tensor, ...], tuple[tuple[Tensor, ...], ...]], + num_warmup_iters: int = 3, + allow_unused_input: bool = False, + pool: Optional[_POOL_HANDLE] = None, +) -> Union[_ModuleOrCallable, tuple[_ModuleOrCallable, ...]]: r"""Accept callables (functions or :class:`nn.Module`\ s) and returns graphed versions. Each graphed callable's forward pass runs its source callable's @@ -300,14 +355,17 @@ def make_graphed_callables( just_one_callable = False + _sample_args: tuple[tuple[Tensor, ...], ...] if not isinstance(callables, tuple): just_one_callable = True callables = (callables,) - sample_args = (sample_args,) + _sample_args = (typing.cast(tuple[Tensor, ...], sample_args),) + else: + _sample_args = typing.cast(tuple[tuple[Tensor, ...], ...], sample_args) flatten_sample_args = [] - for c, args in zip(callables, sample_args): + for c, args in zip(callables, _sample_args): if isinstance(c, torch.nn.Module): assert ( len(c._backward_hooks) == 0 @@ -352,7 +410,7 @@ def make_graphed_callables( torch.cuda.synchronize() with torch.cuda.stream(torch.cuda.Stream()): for func, args, static_input_surface in zip( - callables, sample_args, per_callable_static_input_surfaces + callables, _sample_args, per_callable_static_input_surfaces ): grad_inputs, outputs, outputs_grad = None, None, None for _ in range(num_warmup_iters): @@ -382,11 +440,11 @@ def make_graphed_callables( # Capture forward graphs per_callable_static_outputs = [] per_callable_output_unflatten_spec = [] - for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + for func, args, fwd_graph in zip(callables, _sample_args, fwd_graphs): with torch.cuda.graph(fwd_graph, pool=mempool): - outputs = func(*args) + func_outputs = func(*args) - flatten_outputs, spec = torch.utils._pytree.tree_flatten(outputs) + flatten_outputs, spec = torch.utils._pytree.tree_flatten(func_outputs) per_callable_static_outputs.append(tuple(flatten_outputs)) per_callable_output_unflatten_spec.append(spec) @@ -438,19 +496,19 @@ def make_graphed_callables( # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. def make_graphed_autograd_function( - fwd_graph, - bwd_graph, - module_params, - len_user_args, - output_unflatten_spec, - static_input_surface, - static_outputs, - static_grad_outputs, - static_grad_inputs, - ): + fwd_graph: CUDAGraph, + bwd_graph: CUDAGraph, + module_params: tuple[torch.nn.Parameter, ...], + len_user_args: int, + output_unflatten_spec: torch.utils._pytree.TreeSpec, + static_input_surface: tuple[Tensor, ...], + static_outputs: tuple[Tensor, ...], + static_grad_outputs: tuple[Optional[Tensor], ...], + static_grad_inputs: tuple[Tensor, ...], + ) -> Callable[..., object]: class Graphed(torch.autograd.Function): @staticmethod - def forward(ctx, *inputs): + def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]: # At this stage, only the user args may (potentially) be new tensors. for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): @@ -461,7 +519,7 @@ def forward(ctx, *inputs): @staticmethod @torch.autograd.function.once_differentiable - def backward(ctx, *grads): + def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]: assert len(grads) == len(static_grad_outputs) for g, grad in zip(static_grad_outputs, grads): if g is not None: @@ -477,7 +535,7 @@ def backward(ctx, *grads): b.detach() if b is not None else b for b in static_grad_inputs ) - def functionalized(*user_args): + def functionalized(*user_args: object) -> object: # Runs the autograd function with inputs == all inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. @@ -488,7 +546,7 @@ def functionalized(*user_args): return functionalized # Put together the final graphed callables - ret = [] + ret: list[_ModuleOrCallable] = [] for i, func in enumerate(callables): graphed = make_graphed_autograd_function( fwd_graphs[i], @@ -504,20 +562,25 @@ def functionalized(*user_args): if isinstance(func, torch.nn.Module): - def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): - def new_fwd(*user_args): + def make_graphed_forward( + func: torch.nn.Module, + graph_training_state: bool, + graphed: Callable[_P, _R], + orig_fwd: Callable[_P, _R], + ) -> Callable[_P, _R]: + def new_fwd(*user_args: _P.args, **user_kwargs: _P.kwargs) -> _R: # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method if func.training == graph_training_state: - return graphed(*user_args) + return graphed(*user_args, **user_kwargs) else: - return orig_fwd(*user_args) + return orig_fwd(*user_args, **user_kwargs) return new_fwd func.forward = make_graphed_forward( func, func.training, graphed, func.forward - ) # type: ignore[assignment] + ) ret.append(func) else: ret.append(graphed) diff --git a/torch/distributed/CONTRIBUTING.md b/torch/distributed/CONTRIBUTING.md index bcfe8df9abd4..913017a6cabf 100644 --- a/torch/distributed/CONTRIBUTING.md +++ b/torch/distributed/CONTRIBUTING.md @@ -2,7 +2,7 @@ Please go through PyTorch's top level [Contributing Guide](../../CONTRIBUTING.md) before proceeding with this guide. -[PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We would highly recommend going through some of that material before you start working on PyTorch Distributed. +[PyTorch Distributed Overview](https://pytorch.org/tutorials//beginner/dist_overview.html) is a great starting point with a lot of tutorials, documentation and design docs covering PyTorch Distributed. We highly recommend going through some of that material before you start working on PyTorch Distributed. In this document, we mostly focus on some of the code structure for PyTorch distributed and implementation details. diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index b6ba9919ee84..38e2fdbee803 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -4,6 +4,7 @@ import sys import traceback import typing +from datetime import timedelta import torch @@ -82,7 +83,7 @@ def interaction(self, *args, **kwargs): _breakpoint_cache: dict[int, typing.Any] = {} - def breakpoint(rank: int = 0, skip: int = 0): + def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): """ Set a breakpoint, but only on a single rank. All other ranks will wait for you to be done with the breakpoint before continuing. @@ -99,6 +100,13 @@ def breakpoint(rank: int = 0, skip: int = 0): log.warning("Skip the breakpoint, counter=%d", counter) return + # avoid having the default timeout (if short) interrupt your debug session + if timeout_s is not None: + for group in torch.distributed.distributed_c10d._pg_map: + torch.distributed.distributed_c10d._set_pg_timeout( + timedelta(seconds=timeout_s), group + ) + if get_rank() == rank: pdb = _DistributedPdb() pdb.message( diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index 8704b9015997..7ef73a27a6dd 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -16,7 +16,6 @@ from torch._C._distributed_c10d import ( _current_process_group, _set_process_group, - Backend, ProcessGroup, ReduceOp, Store, @@ -47,7 +46,7 @@ def __call__( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: ... @@ -71,11 +70,11 @@ def _gloo_factory( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: from torch.distributed import ProcessGroupGloo - assert pg_options is None, "Gloo backend does not support options" + assert len(kwargs) == 0, "Gloo backend received unexpected kwargs" backend_class = ProcessGroupGloo(store, rank, world_size, timeout) backend_class._set_sequence_number_for_group() @@ -101,15 +100,18 @@ def _nccl_factory( world_size: int, timeout: timedelta, device: torch.device, - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: from torch.distributed import ProcessGroupNCCL - assert isinstance(pg_options, ProcessGroupNCCL.Options) + opts = ProcessGroupNCCL.Options() + opts._timeout = timeout + for k, v in kwargs.items(): + if not hasattr(opts, k): + raise KeyError(f"Unknown option {k}") + setattr(opts, k, v) - pg_options._timeout = timeout - - backend_class = ProcessGroupNCCL(store, rank, world_size, pg_options) + backend_class = ProcessGroupNCCL(store, rank, world_size, opts) backend_class._set_sequence_number_for_group() backend_class.eager_connect_single_device(device) @@ -128,7 +130,7 @@ def new_group( backend: str, timeout: timedelta, device: Union[str, torch.device], - pg_options: Backend.Options, + **kwargs: object, ) -> ProcessGroup: """ Create a new process group with the given backend and options. This group is @@ -139,7 +141,8 @@ def new_group( backend: The backend to use for the process group. timeout: The timeout for collective operations. device: The device to use for the process group. - pg_options: The options to use for the process group. + **kwargs: All remaining arguments are passed to the backend constructor. + See the backend specific documentation for details. Returns: A new process group. @@ -152,7 +155,7 @@ def new_group( store, rank, world_size = next(iter(rendezvous("env://"))) store.set_timeout(timeout) - return _BACKENDS[backend](store, rank, world_size, timeout, device, pg_options) + return _BACKENDS[backend](store, rank, world_size, timeout, device, **kwargs) def current_process_group() -> ProcessGroup: diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index ec51b2b7a181..0ffae8a9c9fe 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -821,6 +821,10 @@ def _are_we_tracing() -> bool: # If fake mode is turned on, we are almost definitely compiling/tracing. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None: return True + + if torch._dynamo.compiled_autograd.in_compiled_autograd_initial_trace: + return True + return get_proxy_mode() is not None diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 634e953aeb36..b45b902406ea 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -47,10 +47,11 @@ def enable_symm_mem_for_group(group_name: str) -> None: _is_test_mode: bool = False +_mocked_group_names: Optional[set[str]] = None @contextmanager -def _test_mode() -> Generator[None, None, None]: +def _test_mode(group_names: Optional[set[str]] = None) -> Generator[None, None, None]: """ Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops defined in the ``symm_mem`` namespace to use fallback implementations. @@ -58,12 +59,16 @@ def _test_mode() -> Generator[None, None, None]: The context manager is not thread safe. """ global _is_test_mode + global _mocked_group_names prev = _is_test_mode + prev_group_names = _mocked_group_names try: _is_test_mode = True + _mocked_group_names = group_names yield finally: _is_test_mode = prev + _mocked_group_names = prev_group_names def is_symm_mem_enabled_for_group(group_name: str) -> bool: @@ -73,7 +78,9 @@ def is_symm_mem_enabled_for_group(group_name: str) -> bool: Args: group_name (str): the name of the process group. """ - return _is_test_mode or group_name in _group_name_to_store + if _is_test_mode: + return _mocked_group_names is None or group_name in _mocked_group_names + return group_name in _group_name_to_store _group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {} diff --git a/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/torch/distributed/_symmetric_memory/_nvshmem_triton.py index 75abae38c755..dda1885a8e16 100644 --- a/torch/distributed/_symmetric_memory/_nvshmem_triton.py +++ b/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -197,3 +197,86 @@ def quiet(_builder=None): # type: ignore[no-untyped-def] is_pure=False, _builder=_builder, ) + + @core.extern + def my_pe(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_my_pe", core.dtype("int32"))}, + is_pure=True, + _builder=_builder, + ) + + @core.extern + def n_pes(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_n_pes", core.dtype("int32"))}, + is_pure=True, + _builder=_builder, + ) + + @core.extern + def barrier_all(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_barrier_all", core.dtype("int32"))}, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def sync_all(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + {(): ("nvshmem_sync_all", core.dtype("int32"))}, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def alltoall(team, dest, source, nelems, _builder=None): # type: ignore[no-untyped-def] + """Perform alltoall operation on NVSHMEM symmetric memory""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nelems], + { + ( + core.dtype("int64"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # nelems + ): ("nvshmem_longlong_alltoall", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def broadcast(team, dest, source, nelems, pe_root, _builder=None): # type: ignore[no-untyped-def] + """Broadcasts data from a root PE to all other PEs in a team""" + return core.extern_elementwise( + "", + "", + [team, dest, source, nelems, pe_root], + { + ( + core.dtype("int64"), # team handle + core.dtype("int64"), # dest ptr + core.dtype("int64"), # source ptr + core.dtype("int64"), # nelems + core.dtype("int64"), # pe_root + ): ("nvshmem_longlong_broadcast", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) diff --git a/torch/distributed/checkpoint/_async_thread_executor.py b/torch/distributed/checkpoint/_async_thread_executor.py index 1038c177529d..3fad17b2dea9 100644 --- a/torch/distributed/checkpoint/_async_thread_executor.py +++ b/torch/distributed/checkpoint/_async_thread_executor.py @@ -37,7 +37,9 @@ def save_wrapper( class _ThreadBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): def __init__(self) -> None: - self._executor = ThreadPoolExecutor(max_workers=1) + self._executor = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="AsyncCheckpointExecutor" + ) def execute_save( self, diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 86630903a951..dc988e999c4e 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -595,35 +595,6 @@ def _write_sub_tensor_to_file_optimized( ) return - # Check for fully contiguous chunk - expected_chunk_size = math.prod(sub_tensor_shape) * element_size - - if len(sub_tensor_bytes) == expected_chunk_size: - # Calculate if the chunk maps to a contiguous region in the tensor - tensor_strides = [1] - for i in range(len(tensor_shape) - 1, 0, -1): - tensor_strides.insert(0, tensor_strides[0] * tensor_shape[i]) - - # Check if chunk represents a contiguous slice - chunk_start_pos = sum( - offset * stride - for offset, stride in zip(sub_tensor_offsets, tensor_strides) - ) - - # For simple contiguous cases, use direct copy - if all( - offset + size <= dim - for offset, size, dim in zip( - sub_tensor_offsets, sub_tensor_shape, tensor_shape - ) - ): - tensor_start_byte = output_start_byte + chunk_start_pos * element_size - - with fs.open(output_file_path, "r+b") as out_f: - out_f.seek(tensor_start_byte) - out_f.write(sub_tensor_bytes) - return - # Fall back to the original implementation for complex patterns _write_sub_tensor_to_file( fs, diff --git a/torch/distributed/checkpoint/_experimental/staging.py b/torch/distributed/checkpoint/_experimental/staging.py index 55e4c15921a2..3e57650b5880 100644 --- a/torch/distributed/checkpoint/_experimental/staging.py +++ b/torch/distributed/checkpoint/_experimental/staging.py @@ -82,7 +82,7 @@ class CheckpointStagerConfig: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -93,7 +93,7 @@ class CheckpointStagerConfig: use_pinned_memory: bool = True use_shared_memory: bool = True use_async_staging: bool = True - use_cuda_non_blocking_copy: bool = True + use_non_blocking_copy: bool = True class DefaultStager(CheckpointStager): @@ -153,15 +153,17 @@ def __init__( if self._config.use_async_staging: self._staging_executor = ThreadPoolExecutor(max_workers=1) - if torch.cuda.is_available(): + if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - self._staging_stream = torch.cuda.Stream() + self._staging_stream = torch.Stream() - if self._config.use_cuda_non_blocking_copy: - assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + if self._config.use_non_blocking_copy: + assert torch.accelerator.is_available(), ( + "Non-blocking copy requires that the current accelerator is available." + ) def stage( self, @@ -182,16 +184,16 @@ def stage( def _stage(self, state_dict: STATE_DICT, **kwargs: Any) -> STATE_DICT: state_dict = self._state_dict_stager.stage( - state_dict, non_blocking=self._config.use_cuda_non_blocking_copy, **kwargs + state_dict, non_blocking=self._config.use_non_blocking_copy, **kwargs ) - if self._config.use_cuda_non_blocking_copy: + if self._config.use_non_blocking_copy: assert self._staging_stream or not self._config.use_async_staging, ( - "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." ) # waits for the enqued copy operations to finish. - self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() return state_dict diff --git a/torch/distributed/checkpoint/_pg_transport.py b/torch/distributed/checkpoint/_pg_transport.py index cab908b5a851..f4c53829b23b 100644 --- a/torch/distributed/checkpoint/_pg_transport.py +++ b/torch/distributed/checkpoint/_pg_transport.py @@ -9,6 +9,12 @@ import torch from torch.distributed import ProcessGroup, Work +from torch.distributed._shard.sharded_tensor import ( + Shard as ShardedTensorShard, + ShardedTensor, + ShardMetadata, +) +from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata from torch.distributed.tensor import _DTensorSpec, DTensor from torch.utils._pytree import ( KeyPath, @@ -53,6 +59,22 @@ class _DTensorMeta: spec: _DTensorSpec +@dataclass +class _ShardedTensorMeta: + """ + This is the metadata for a ShardedTensor that is used to transfer checkpoints. + It contains the metadata for all local shards and the global tensor metadata. + + This must be pickleable so that it can be sent over the wire. + """ + + local_shards_meta: list[_TensorMeta] + local_shards_shard_metadata: list[ + ShardMetadata + ] # Original shard metadata for each local shard + sharded_tensor_metadata: ShardedTensorMetadata + + @dataclass class _StateDictMeta: """ @@ -72,7 +94,9 @@ class _StateDictMeta: treespec: TreeSpec paths: list[KeyPath] - non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] @contextmanager @@ -104,7 +128,9 @@ def _prepare_state_dict( leaves, treespec = tree_flatten_with_path(state_dict) paths: list[KeyPath] = [] - non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = [] + non_tensor_leaves: list[ + Union[object, _TensorMeta, _DTensorMeta, _ShardedTensorMeta] + ] = [] tensors: list[torch.Tensor] = [] for key_path, v in leaves: paths.append(key_path) @@ -120,6 +146,26 @@ def _prepare_state_dict( spec=v._spec, ) ) + elif isinstance(v, ShardedTensor): + # Handle ShardedTensor by extracting all local shards + local_shards = v.local_shards() + + # Prepare metadata for all local shards + local_shards_meta = [] + local_shards_shard_metadata = [] + for shard in local_shards: + tensor, tensor_meta = _prepare_tensor(shard.tensor) + tensors.append(tensor) + local_shards_meta.append(tensor_meta) + local_shards_shard_metadata.append(shard.metadata) + + non_tensor_leaves.append( + _ShardedTensorMeta( + local_shards_meta=local_shards_meta, + local_shards_shard_metadata=local_shards_shard_metadata, + sharded_tensor_metadata=v.metadata(), # Complete metadata + ) + ) elif isinstance(v, torch.Tensor): tensor, tensor_meta = _prepare_tensor(v) tensors.append(tensor) @@ -242,7 +288,6 @@ def recv_checkpoint(self, src_rank: int) -> object: Returns: The reconstructed state dictionary with model parameters """ - state_dict = self._state_dict() if self._state_dict else {} state_dict_leaves, _ = tree_flatten_with_path(state_dict) @@ -301,6 +346,37 @@ def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor: elif isinstance(v, _DTensorMeta): tensor = recv(path, v.local) values.append(DTensor(tensor, v.spec, requires_grad=False)) + elif isinstance(v, _ShardedTensorMeta): + # Receive all local shards that were sent to us + local_shards = [] + current_rank = self._pg.rank() + + # Receive tensors for each local shard that was sent + for j, shard_meta in enumerate(v.local_shards_meta): + tensor = recv(path, shard_meta) + + # Use the original shard metadata that was stored during preparation + # but update the placement to reflect the current rank/device + original_shard_metadata = v.local_shards_shard_metadata[j] + updated_shard_metadata = ShardMetadata( + shard_offsets=original_shard_metadata.shard_offsets, + shard_sizes=original_shard_metadata.shard_sizes, + placement=f"rank:{current_rank}/{tensor.device.type}", + ) + + local_shard = ShardedTensorShard( + tensor=tensor, metadata=updated_shard_metadata + ) + local_shards.append(local_shard) + + # Use complete metadata to reconstruct ShardedTensor + sharded_tensor = ( + ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards=local_shards, + sharded_tensor_metadata=v.sharded_tensor_metadata, + ) + ) + values.append(sharded_tensor) else: values.append(v) diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index 0da3fc089f88..13fd61910dd2 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -67,6 +67,7 @@ def __init__( token: Optional[str] = None, save_distributed: bool = False, enable_consolidation: bool = False, + consolidated_output_path: Optional[str] = None, thread_count_consolidation: int = 1, ) -> None: """ diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index a2093f803ee6..9e1031c7fdda 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -110,7 +110,7 @@ class StagingOptions: use_async_staging (bool): Enable asynchronous staging using a background thread pool. Allows overlapping computation with staging operations. Requires CUDA. Default: True - use_cuda_non_blocking_copy (bool): Use non-blocking CUDA memory + use_non_blocking_copy (bool): Use non-blocking device memory copies with stream synchronization. Improves performance by allowing CPU work to continue during GPU transfers. Default: True @@ -121,7 +121,7 @@ class StagingOptions: use_pinned_memory: bool = True use_shared_memory: bool = True use_async_staging: bool = True - use_cuda_non_blocking_copy: bool = True + use_non_blocking_copy: bool = True class DefaultStager(AsyncStager): @@ -177,15 +177,17 @@ def __init__( self._staging_stream = None if self._config.use_async_staging: self._staging_executor = ThreadPoolExecutor(max_workers=1) - if torch.cuda.is_available(): + if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - self._staging_stream = torch.cuda.Stream() + self._staging_stream = torch.Stream() - if self._config.use_cuda_non_blocking_copy: - assert torch.cuda.is_available(), "Non-blocking copy requires CUDA" + if self._config.use_non_blocking_copy: + assert torch.accelerator.is_available(), ( + "Non-blocking copy requires that the current accelerator is available." + ) self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None @@ -216,9 +218,9 @@ def stage( return self._stage(state_dict, **kwargs) def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: - if self._config.use_cuda_non_blocking_copy: + if self._config.use_non_blocking_copy: assert self._staging_stream or not self._config.use_async_staging, ( - "Non-blocking cuda copy in a background thread for async staging needs staging_stream to be initialized." + "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized." ) with ( self._staging_stream @@ -226,10 +228,10 @@ def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE: else nullcontext() ): state_dict = self._state_dict_stager.stage( - state_dict, non_blocking=self._config.use_cuda_non_blocking_copy + state_dict, non_blocking=self._config.use_non_blocking_copy ) # waits for the enqued copy operations to finish. - self._staging_stream.synchronize() if self._staging_stream else torch.cuda.synchronize() + self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize() else: state_dict = self._state_dict_stager.stage(state_dict, non_blocking=False) return state_dict diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index dfd07c707a7f..370bab11b4db 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -644,11 +644,15 @@ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: def __repr__(self) -> str: device_mesh_repr = ( - f"DeviceMesh('{self.device_type}', {self.mesh.tolist()})" - if not self.mesh_dim_names - else f"DeviceMesh('{self.device_type}', {self.mesh.tolist()}, mesh_dim_names={self.mesh_dim_names})" + f"({', '.join(f'{k}={v}' for k, v in zip(self.mesh_dim_names, self.mesh.shape))})" + if self.mesh_dim_names + else f"{tuple(self.mesh.shape)}" ) - return device_mesh_repr + device_mesh_repr = f"DeviceMesh({device_mesh_repr}, device: '{self.device_type}', stride: {self.mesh.stride()}" + # We only print the mesh tensor if the debug mode is turned on. + if os.environ.get("TORCH_DISTRIBUTED_DEBUG", "") == "DETAIL": + device_mesh_repr += f", Mesh: {self.mesh.tolist()}" + return f"{device_mesh_repr})" def __hash__(self): # lazily compute hash @@ -1007,7 +1011,7 @@ def init_device_mesh( required for distributed communications behind the scene. Args: - device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like". + device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like", "xpu". Passing in a device type with a GPU index, such as "cuda:0", is not allowed. mesh_shape (Tuple[int]): A tuple defining the dimensions of the multi-dimensional array describing the layout of devices. diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2bc99a51cd64..d96cc61a5ac7 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1568,12 +1568,12 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on build-time configurations, valid values include ``mpi``, ``gloo``, - ``nccl``, ``ucc``, or one that is registered by a third-party + ``nccl``, ``ucc``, ``xccl`` or one that is registered by a third-party plugin. Since 2.6, if ``backend`` is not provided, c10d will use a backend registered for the device type indicated by the `device_id` kwarg (if provided). The known default registrations today are: ``nccl`` - for ``cuda``, ``gloo`` for ``cpu``. + for ``cuda``, ``gloo`` for ``cpu``, ``xccl`` for ``xpu``. If neither ``backend`` nor ``device_id`` is provided, c10d will detect the accelerator on the run-time machine and use a backend registered for that detected accelerator (or ``cpu``). diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 855a706e6d30..7649c32ec1c0 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -292,21 +292,22 @@ def _init_sharded_param( dp_global_mesh is None or tp_global_mesh is None ): raise AssertionError( - "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" - f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" + "FSDP requires the DP and model parallel TP/EP mesh to have the same parent mesh but got: \n" + f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}" ) name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" assert dp_mesh.mesh_dim_names is not None, name_dims_error assert tp_mesh.mesh_dim_names is not None, name_dims_error submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names self._spmd_mesh = dp_global_mesh[submesh_names] - if len(self._tp_spec.placements) != 1: + if len(self._tp_spec.placements) > 2: raise NotImplementedError( - f"FSDP only supports 1D TP, not {self._tp_spec.placements}" + f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" ) split_factor = self._tp_spec.num_shards_map[shard_dim] - assert 2 <= self._spmd_mesh.ndim <= 3, ( - f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + assert 2 <= self._spmd_mesh.ndim <= 4, ( + "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " + f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." ) self._spmd_placements: tuple[Placement, ...] dp_shard_tp_placement = ( @@ -315,11 +316,11 @@ def _init_sharded_param( if split_factor > 1 else fsdp_placement ), - self._tp_spec.placements[0], + *self._tp_spec.placements, ) - if self._spmd_mesh.ndim == 2: + if dp_mesh.ndim == 1: # FSDP self._spmd_placements = dp_shard_tp_placement - else: + else: # HSDP assert self.mesh_info.replicate_mesh_dim == 0 self._spmd_placements = (Replicate(),) + dp_shard_tp_placement self._sharding_spec = DTensorSpec( diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index d60e9bd30781..f4ded5a1f0bc 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -248,13 +248,13 @@ def __init__( logger.info("Using %s", self.__class__.__name__) def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): - if stage.is_last and self._has_backward: + if stage.is_last and self._loss_fn is not None: loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] self._internal_losses.append(loss) def _maybe_get_loss(self, stage, mb_index): valid_index = 0 <= mb_index < len(self._internal_losses) - if stage.is_last and self._has_backward and valid_index: + if stage.is_last and self._loss_fn is not None and valid_index: return self._internal_losses[mb_index] elif len(self._internal_losses) != 0 and not valid_index: raise RuntimeError( @@ -319,6 +319,26 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): """ raise NotImplementedError + def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches, calling forward only. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target values for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Save the original has_backward state + original_has_backward = self._has_backward + try: + self._has_backward = False + return self.step(*args, target=target, losses=losses, **kwargs) + finally: + # Restore the original state + self._has_backward = original_has_backward + def _check_inputs( self, arg_mbs: Optional[list] = None, @@ -475,8 +495,6 @@ def __init__( # Self attributes self._stage = stage self._num_stages = stage.num_stages - # Set the same has_backward flag for stage object - self._stage.has_backward = self._has_backward self._stage_initialized = False if n_microbatches < self._num_stages: @@ -506,6 +524,15 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ + if not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward # Clean per iteration self._stage.clear_runtime_states() @@ -650,10 +677,6 @@ def _step_microbatches( for work in fwd_sends_to_wait: _wait_batch_p2p(work) - # No loss function, no need to run backward - if not self._has_backward: - return - # Run backward # Delay send waits bwd_sends_to_wait: list[list[dist.Work]] = [] @@ -681,13 +704,13 @@ def _step_microbatches( grad_scale_factor=self._n_microbatches if self.scale_grads else 1 ) - # Return losses if there is a container passed in - self._update_losses(self._stage, losses) - # Wait for all backward sends to finish for work in bwd_sends_to_wait: _wait_batch_p2p(work) + # Update losses if there is a container passed in + self._update_losses(self._stage, losses) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: """ Returns the pipeline order for GPipe schedule. @@ -1264,9 +1287,6 @@ def __init__( for stage in self._stages: stage.stage_index_to_group_rank = self.stage_index_to_group_rank - # Set the same has_backward flag for stage object - for stage in self._stages: - stage.has_backward = self._has_backward self._stages_initialized = False # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle @@ -1349,6 +1369,17 @@ def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): target: target for the loss function. losses: a list to store the losses for each microbatch. """ + if not torch.is_grad_enabled(): + raise RuntimeError( + "step() requires gradients to be enabled for backward computation; " + "it should not be used under torch.no_grad() context. " + "Please call eval() instead." + ) + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + # Clean per iteration for stage in self._stages: stage.clear_runtime_states() diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index df229c983209..e22799545903 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -462,11 +462,10 @@ def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Get the gradient send ops for current stage's backward. """ - self._check_chunk_id(bwd_chunk_id) - if not self.has_backward or self.is_first: return [] + self._check_chunk_id(bwd_chunk_id) # Create bwd send infra lazily if self.grad_send_info is None: # Send info for input grads during backward: @@ -761,6 +760,10 @@ def backward_one_chunk( last_backward is controlled by the schedule and signals synchronization of gradients across DP groups after the last backward. """ + # skip backward computation if backward is not enabled + if not self.has_backward: + return + self._check_chunk_id(bwd_chunk_id) ( diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 19c8739e0581..b0ee136c135f 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -571,7 +571,7 @@ def full_tensor( """ Return the full tensor of this DTensor. It will perform necessary collectives to gather the local tensors from other ranks in its DeviceMesh and concatenate - them together. It's a syntatic sugar of the following code: + them together. It's a syntactic sugar of the following code: ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` @@ -1011,7 +1011,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # set default placements to replicated if not specified placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) - # check device_mesh againts placements + # check device_mesh against placements assert device_mesh.ndim == len(placements), ( "mesh dimension does not match the length of placements" ) diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index 01505fddd0fd..36316b2f0567 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -325,7 +325,7 @@ def redistribute_cost( NOTE: 1. Only consider communication cost here, since computation costs for redistribute - are quite trival (i.e. we only need to narrow or simple division) + are quite trivial (i.e. we only need to narrow or simple division) 2. Only consider redistribute cost on same mesh, cross mesh communication cost is not quite needed for operator strategy estimation/selection. """ diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 83270b5a64bb..1d0f57102aae 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -434,7 +434,7 @@ def _try_replicate_spec_for_scalar_tensor( "Found a non-scalar tensor with numel=1 and ndim!=0, " "we are implicitly creating a replicated DTensor for it. " "However, please consider changing it to a scalar tensor " - "or explicitly create a DTensor under distributed enviroment." + "or explicitly create a DTensor under distributed environment." ) if tensor_arg.numel() == 1 or self._allow_implicit_replication: diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 360f1a0ea016..eb528ee4f9af 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -240,7 +240,7 @@ def from_dim_map( if placement.is_shard(): placement = cast(Shard, placement) raise RuntimeError( - f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" ) elif placement.is_partial(): raise RuntimeError( diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index acf15c6c0ea4..0adaa2e4ad08 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -28,7 +28,7 @@ PlacementList = list[Optional[Placement]] -# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should # be the same set of possibilities. OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] @@ -73,12 +73,37 @@ class OpSpec: invariant: the DeviceMesh on all DTensorSpec must be the same """ + # output_specs and input_specs are related: for this op, given these input_specs, + # this is the way the output would look output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] input_specs: Optional[Sequence[DTensorSpec]] = None - # redistribute costs to redistribute the operator input shardings to this OpSpec. - # Note that We need a nested list to record the cost for each operand of this - # operator, and for each operand of this operator it might have multiple OpSpecs. + """ + redistribute_cost tells how expensive it is to redistribute a given input into the + placement specified in this OpSpec. + + outer list: one entry (list) per (tensor) input in the op's arg schema + inner list: one entry (cost value) per possible sharding spec for that input + + Example: + ------- + another_op() -> tensor_a # another_op produces the output that becomes our first input + my_op(tensor_a) + + Let's assume this OpSpec's input_specs are [Replicate()], + but another_op() supports 2 strategies (OpSpecs) which produce outputs of + Replicate() + Shard(0) + + In this example, redistribute_costs would look like this + [ + # one row representing "my_op's first input" (tensor_a) + [ + # two entries, one for each strategies supported by another_op + 0.0, # cost of redistributing tensor_a from 'Replicate()' + K, # cost of redistributing tensor_a from 'Shard(0)' + ], + """ redistribute_cost: Optional[list[list[float]]] = None @cached_property @@ -317,15 +342,6 @@ def __str__(self) -> str: args_schema.append(str(arg)) return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})" - def __post_init__(self) -> None: - has_symints = False - for a in self.args_schema: - if isinstance(a, DTensorSpec) and a.tensor_meta is not None: - if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): - has_symints = True - break - self.has_symints = has_symints - def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: arg = self.args_schema[arg_idx] is_tensor = isinstance(arg, DTensorSpec) @@ -345,6 +361,13 @@ def return_type_tuple_tensor_like(self) -> bool: return_types[0].type, torch.TensorType ) + def return_type_list_tensor_like(self) -> bool: + # returns True if the return type is a List + return_types = self.op._schema.returns + return len(return_types) == 1 and isinstance( + return_types[0].type, torch.ListType + ) + def return_type_tensor(self) -> bool: return_types = self.op._schema.returns # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 9d316aff4ed8..1b8e47895ce5 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -113,7 +113,7 @@ def _partition_value( def _reduce_value( self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask + # by the time we need reduction, we should have already saved the mask assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction @@ -134,7 +134,7 @@ def _reduce_shard_value( mesh_dim: int, shard_spec: Placement, ) -> torch.Tensor: - # by the time we ned reduction, we should have already saved the mask + # by the time we need reduction, we should have already saved the mask assert self.mask_buffer.data is not None # apply the mask to the tensor that pending reduction diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 49df9c63a9ea..9e875936c264 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -904,34 +904,48 @@ def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: return output_strategy -@register_op_strategy( - [aten.native_layer_norm_backward.default], - schema_info=RuntimeSchemaInfo(2), -) -def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: +def _common_norm_backward_strategy( + op_schema: OpSchema, + rms_norm: bool = False, +) -> OpStrategy: + """Common backward strategy logic for layer_norm and rms_norm.""" # backward op does not need to validate the mesh since forward op has already done it mesh = op_schema.get_mesh_from_args(validate=False) - # args must be: grad_out, input, normalized_shape, mean, rstd, - # weight, bias, output_mask. For None weight and bias, their - # corresponding objects will be None as well. - - assert len(op_schema.args_schema) == 8 - ( - grad_out_strategy, - input_strategy, - normalized_shape, - mean_strategy, - rstd_strategy, - weight_strategy, - bias_strategy, - output_mask, - ) = op_schema.args_schema + if not rms_norm: + # layer_norm args: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. + assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + else: + # rms_norm args: grad_out, input, normalized_shape, rstd, + assert len(op_schema.args_schema) == 6 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + rstd_strategy, + weight_strategy, + output_mask, + ) = op_schema.args_schema + mean_strategy = None + bias_strategy = None assert isinstance(grad_out_strategy, OpStrategy) assert isinstance(input_strategy, OpStrategy) - assert isinstance(mean_strategy, OpStrategy) assert isinstance(rstd_strategy, OpStrategy) + if mean_strategy is not None: + assert isinstance(mean_strategy, OpStrategy) assert isinstance(normalized_shape, (int, Sequence, torch.Size)) normalized_size = normalize_to_torch_size(normalized_shape) @@ -939,9 +953,12 @@ def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: axis = input_ndim - len(normalized_size) outer_dims = list(range(axis)) - assert isinstance(output_mask, list) and len(output_mask) == 3 + if not rms_norm: + assert isinstance(output_mask, list) and len(output_mask) == 3 + else: + assert isinstance(output_mask, list) and len(output_mask) == 2 - # output triple: (d_input, d_weight, d_bias) + # output tuple: (d_input, d_weight[, d_bias]) out_tuple_strategy = OpStrategy([]) for idx, input_placement_strategy in enumerate(input_strategy.strategies): # args for OpSpec @@ -982,10 +999,14 @@ def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: generate_redistribute_costs(input_strategy, input_target_spec) ) - # arg: mean, rstd - mean_src_spec = mean_strategy.strategies[idx].output_spec - input_specs_list.append(mean_src_spec) - redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + # arg: mean + if not rms_norm: + assert mean_strategy is not None # mypy fix + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + + # arg: rstd rstd_src_spec = rstd_strategy.strategies[idx].output_spec input_specs_list.append(rstd_src_spec) redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) @@ -1001,6 +1022,7 @@ def _add_target_input_spec(strategy) -> DTensorSpec: # arg: weight # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False) if weight_strategy is not None: weight_src_spec = _add_target_input_spec(weight_strategy) # TODO: now d_weight spec follows input spec w/ a reduction. @@ -1020,36 +1042,39 @@ def _add_target_input_spec(strategy) -> DTensorSpec: ) output_specs_list.append(weight_out_spec if output_mask[1] else None) else: - assert output_mask[1] is False, ( - "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." - ) + if not rms_norm: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + else: + error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward." + assert output_mask[1] is False, error_msg output_specs_list.append(None) # arg: bias # d_bias = sum(grad_out, outer_dim, keepdim=False) - if bias_strategy is not None: - bias_src_spec = _add_target_input_spec(bias_strategy) - # d_bias spec follows a reduction over grad_out - inp_placements = _replicate_dims_start_at( - grad_out_target_spec.placements, axis - ) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, grad_out_target_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - bias_out_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=bias_src_spec.tensor_meta, - ) - output_specs_list.append(bias_out_spec if output_mask[2] else None) - else: - assert output_mask[2] is False, ( - "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." - ) - output_specs_list.append(None) + if not rms_norm: + if bias_strategy is not None: + bias_src_spec = _add_target_input_spec(bias_strategy) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at( + grad_out_target_spec.placements, axis + ) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_target_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + bias_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + output_specs_list.append(bias_out_spec if output_mask[2] else None) + else: + assert output_mask[2] is False, ( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) out_tuple_strategy.strategies.append( OpSpec( @@ -1062,6 +1087,22 @@ def _add_target_input_spec(strategy) -> DTensorSpec: return out_tuple_strategy +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema) + + +@register_op_strategy( + [aten._fused_rms_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + return _common_norm_backward_strategy(op_schema, rms_norm=True) + + @register_op_strategy( [aten.topk.default], schema_info=RuntimeSchemaInfo(2), @@ -1085,8 +1126,32 @@ def topk_strategy(op_schema: OpSchema) -> OpStrategy: if dim != topk_dim: dim_shardings: PlacementList = [Shard(dim)] * 3 single_mesh_dim_strategies.append(dim_shardings) - # TODO: topk on sharded dim requries non-trival reduction, address it later + # TODO: topk on sharded dim requires non-trival reduction, address it later return expand_to_full_mesh_op_strategy( input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 ) + + +@register_op_strategy( + [aten.histc.default], + # strategy choice depends on the value of 'min' and 'max' kwargs, which are position 2 and 3 + schema_info=RuntimeSchemaInfo(2), +) +def histc_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + single_mesh_dim_strategies: list[PlacementList] = [] + single_mesh_dim_strategies.append([Replicate(), Replicate()]) + + # histc can support sharded input and partial output on any input dim, provided the min and max + # values are user-specified. If not user-specified, the true min and max of the data in each local + # tensor will be used to compute bin boundaries, which will not be the same across ranks, leading to + # an incorrect final result + if len(op_schema.args_schema) == 4: + for dim in range(input_strategy.ndim): + dim_shardings: PlacementList = [Partial(), Shard(dim)] + single_mesh_dim_strategies.append(dim_shardings) + + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies + ) diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index b7804d318104..fa4446a2d15e 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -6,7 +6,7 @@ import torch from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpSchema, OpSpec, @@ -24,6 +24,10 @@ prod, register_op_strategy, ) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) from torch.distributed.tensor.placement_types import ( Partial, Placement, @@ -700,7 +704,7 @@ def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrate None, # max_k None, # philox_seed None, # philox_offset - # NOTE: debug_attn_mask is not supproted by pytorch and is always an empty tensor + # NOTE: debug_attn_mask is not supported by pytorch and is always an empty tensor # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840 debug_attn_mask_sharding, # debug_attn_mask Replicate(), # q @@ -1035,6 +1039,51 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: ] ) + def valid_grouped_mm_strides( + input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...] + ) -> bool: + # 1. compute the local-tensor shape/strides given this sharding proposal + # 2. apply the logic from the groped_mm meta function + # UGH the input DTensorSpecs are missing their tensormetas... so i can get them another way + def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: + assert isinstance(spec.output_specs, DTensorSpec) + assert isinstance(spec.output_specs.tensor_meta, TensorMeta) + meta: TensorMeta = spec.output_specs.tensor_meta + local_stride = compute_local_stride(meta.stride, mesh, placements) + local_shape, _ = compute_local_shape_and_global_offset( + meta.shape, mesh, placements + ) + return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) + + mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements) + mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements) + + def check_valid_strides(meta: TensorMeta) -> bool: + # copied from `_meta_grouped_mm_common` in meta_registrations.py + end_dim = len(meta.shape) - 1 + alignment = 16 // meta.dtype.itemsize + if meta.stride[end_dim - 1] == 1 and meta.stride[end_dim] >= max( + 1, meta.shape[end_dim - 1] + ): + if not meta.stride[end_dim] % alignment == 0: + return False + elif meta.stride[end_dim] == 1 and meta.stride[end_dim - 1] >= max( + 1, meta.shape[end_dim] + ): + if not meta.stride[end_dim - 1] % alignment == 0: + return False + else: + return False + return True + + mat1_valid = check_valid_strides(mat1_meta) + mat2_valid = check_valid_strides(mat2_meta) + return mat1_valid and mat2_valid + return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=1 + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=1, + is_valid_strategy_cb=valid_grouped_mm_strides, ) diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index d50622649983..46fc8fbc0d99 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -134,8 +134,14 @@ aten.ceil.out, aten.ceil_.default, aten.clamp.default, + aten.clamp.Tensor, aten.clamp.out, aten.clamp_.default, + aten.clamp_.Tensor, + aten.clamp_min.default, + aten.clamp_min.Tensor, + aten.clamp_max.default, + aten.clamp_max.Tensor, aten.clip.default, aten.clip.out, aten.clip_.default, diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fe957d2ccab6..3b6b8c33cdbd 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -4,6 +4,7 @@ from typing import cast, Optional import torch +from torch._prims_common import IntLike from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, @@ -34,59 +35,78 @@ Shard, ) +from ._pointwise_ops import pointwise_strategy + aten = torch.ops.aten -def default_strategy(op_schema: OpSchema) -> StrategyType: - # Default strategy by default just propagate the first input strategy - select_strategy = op_schema.args_schema[0] - assert isinstance(select_strategy, OpStrategy) - # we create new DTensorSpecs even for default strategy to assure that - # the tensor metas are distinct between the arguments and outputs - input_specs = [] - redistribute_cost = [] - for i in op_schema.args_schema: - input_specs.append( - DTensorSpec( - mesh=select_strategy.mesh, - placements=select_strategy.strategies[0].output_spec.placements, - tensor_meta=select_strategy.strategies[0].output_spec.tensor_meta, +def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: + # For ops with a single tensor input, we perform a 1:1 mapping such that + # for each strategy that the input supports, we create a corresponding strategy. + # Note: this may be a complete waste of work, because it should be equivalent to + # `return first_input_strategy` (unless creating a deep copy is important for some reason) + assert len([s for s in op_schema.args_schema if isinstance(s, OpStrategy)]) == 1, ( + "propagate_single_input_strategy only works for single-tensor-input ops" + ) + first_input_strategy = op_schema.args_schema[0] + assert isinstance(first_input_strategy, OpStrategy) + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ), + input_specs=[ + DTensorSpec( + mesh=first_input_strategy.mesh, + placements=strategy.output_spec.placements, + tensor_meta=strategy.output_spec.tensor_meta, + ) + ], + redistribute_cost=[ + generate_redistribute_costs( + first_input_strategy, strategy.output_spec + ) + ], ) - ) - redistribute_cost.append([0.0] * len(select_strategy.strategies)) - - default_strategy = [ - OpSpec( - output_specs=DTensorSpec( - mesh=select_strategy.mesh, - placements=strategy.output_spec.placements, - tensor_meta=strategy.output_spec.tensor_meta, - ), - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - for strategy in select_strategy.strategies - ] - return OpStrategy(default_strategy) + for strategy in first_input_strategy.strategies + ] + ) register_op_strategy( [ aten.clone.default, aten.contiguous.default, - aten.copy_.default, aten.detach.default, aten.fill_.Scalar, aten.view.dtype, aten.zero_.default, ] -)(default_strategy) +)(propagate_single_input_strategy) register_op_strategy( aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(default_strategy) +)(propagate_single_input_strategy) + +# copy_ is actually a pointwise op with broadcasting, so reuse the pointwise strategy, which takes care of these +# requirements. +# +# Following torch broadcasting semantics (https://docs.pytorch.org/docs/stable/notes/broadcasting.html) +# - self can not change shape as a result of broadcasting since this is an inplace op +# - src can broadcast, but when it does it always does so from the trailing end +# e.g. the last dim of 'src' must match up with the last dim of 'self' +# +# DTensor semantics for inplace ops also dictates that we may NOT redistribute our 'self' input. +# In practice, what this means is +# - our output strategies should map 1:1 to our 'self' input strategies +# - our 'src' input may be redistributed to match up with the 'self' input, with the caveat of adjusting for +# broadcasting dim +register_op_strategy(aten.copy_.default)(pointwise_strategy) @register_op_strategy( @@ -376,14 +396,14 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: start = 0 if end is None or end > input_shape[dim]: end = input_shape[dim] - assert isinstance(start, int) - assert isinstance(end, int) - assert isinstance(step, int) + assert isinstance(start, IntLike) + assert isinstance(end, IntLike) + assert isinstance(step, IntLike) # normalize args - slice_dim = normalize_dim(dim, input_ndim) - start = normalize_dim(start, input_shape[dim]) - end = normalize_dim(end, input_shape[dim]) + slice_dim = normalize_dim(dim, input_ndim) # type: ignore[arg-type] + start = normalize_dim(start, input_shape[dim]) # type: ignore[arg-type] + end = normalize_dim(end, input_shape[dim]) # type: ignore[arg-type] redundant_slice = start == 0 and end == input_shape[dim] and step == 1 @@ -1073,7 +1093,7 @@ def place(vp: Placement, ip: Placement) -> Placement: ], RuntimeSchemaInfo(1), ) -def split_strategy(op_schema: OpSchema) -> TupleStrategy: +def split_strategy(op_schema: OpSchema) -> OpStrategy: input_strategy = op_schema.args_schema[0] split_size_or_sections = op_schema.args_schema[1] assert isinstance(input_strategy, OpStrategy) @@ -1096,23 +1116,27 @@ def size_split(N, i) -> list: ) assert isinstance(output_size_list, Sized) - split_strategies = [] - - for _ in range(len(output_size_list)): - op_strategy = OpStrategy([]) - - for strategy in input_strategy.strategies: - spec = strategy.output_spec - placements = spec.placements - if is_tensor_dim_sharded(spec, dim=dim): - # if the input is sharded on the split dim, we need to unshard it - placements = unshard_tensor_dim(spec.placements, dim=dim) - - spec = DTensorSpec(spec.mesh, placements) - - op_strategy.strategies.append( - OpSpec(output_specs=spec, input_specs=([spec])) + all_strategies = [] + for strategy in input_strategy.strategies: + spec = strategy.output_spec + placements = spec.placements + if is_tensor_dim_sharded(spec, dim=dim): + # if the input is sharded on the split dim, we need to unshard it + placements = unshard_tensor_dim(spec.placements, dim=dim) + + input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta) + output_specs = tuple( + DTensorSpec(spec.device_mesh, placements) + for _ in range(len(output_size_list)) + ) + all_strategies.append( + OpSpec( + output_specs=output_specs, + input_specs=(input_spec,), + redistribute_cost=[ + generate_redistribute_costs(input_strategy, input_spec) + ], ) - split_strategies.append(op_strategy) + ) - return TupleStrategy(split_strategies) + return OpStrategy(all_strategies) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 8fe213f39846..c942da67cd8a 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -300,7 +300,7 @@ def view_groups(from_size: Shape, to_size: Shape) -> DimMap: Flatten((InputDim(1), InputDim(2))) ) - - ouptut dimension 0 maps to input dimension 0 + - output dimension 0 maps to input dimension 0 - output dimension 1 maps to a flattened input dimensions 1 and 2 diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 5215795b085d..8e07d0d6c1f7 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -134,6 +134,8 @@ def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: for i, placement in enumerate(spec.placements): if placement.is_shard(): shard_dim = cast(Shard, placement).dim + if shard_dim >= len(shape): + return False shards_map[shard_dim] *= spec.mesh.size(i) for i, dim_size in enumerate(shape): @@ -216,7 +218,7 @@ def map_placements_after_broadcast( # the input shape shard dim before broadcasting, # in this case it means implicit broadcasting happen # in this dim, so we can just mark it as replicate - # and implict broadcast will broadcast automatically + # and implicit broadcast will broadcast automatically # to the sharded shape new_placements.append(Replicate()) @@ -226,6 +228,11 @@ def map_placements_after_broadcast( def generate_redistribute_costs( src_strategy: OpStrategy, dst_spec: DTensorSpec ) -> list[float]: + """Generates one row in the 'redistribute_costs' matrix in an OpSpec + The length of the returned list will match the number of strategies in 'src_strategy'. + + Each value in the row is the cost of redistributing from a particular src_strategy to dst_spec. + """ redistribute_costs: list[float] = [ redistribute_cost(strat.output_spec, dst_spec) for strat in src_strategy.strategies @@ -241,7 +248,36 @@ def expand_to_full_mesh_op_strategy( *, input_index: int = 1, inplace_op: bool = False, + is_valid_strategy_cb: Optional[ + Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool] + ] = None, ) -> OpStrategy: + """ + Convenience function to allow writing a sharding strategy considering only a single mesh dimension, + and have it expanded combinatorically to all mesh dimensions. + + Args: + mesh (DeviceMesh): the device mesh to expand the strategy to + op_schema (OpSchema): the op schema + single_mesh_dim_strategies (list[PlacementList]): the sharding strategies to expand. The outer list is over + different strategies. The inner PlacementList is over the outputs and inputs of the op. If input_index is 1, + a PlacementList looks like [output_placement, input_placement1, input_placement2, ...]. + input_index: the number of outputs of the op, defaults to 1 + inplace_op: whether the op is inplace or not, defaults to False + is_valid_strategy_cb: a callback function to filter out invalid sharding rules, defaults to None. + + Example: Let's say `my_op(tensor_x, tensor_y) - > output_tensor` can support sharding or replicating tensor_x, + but always requires tensor_y to be replicated. We can specify these valid combinations ignoring mesh dims. + Then, we can rely on `expand_to_full_mesh_op_strategy` to create every possible combination of these shardings + over multiple mesh dimensions, filtering out any combinations that are invalid based on the actual mesh dim size. + + single_mesh_dim_strategies = [ + # first strategy: return output sharded on first dim, shard tensor_x on its first dim, replicate tensor_y + [Shard(0), Shard(0), Replicate()] + # second strategy: replicate output, and both inputs + [Replicate(), Replicate(), Replicate()] + ] + """ # Expand the single_mesh_dim_strategies to full mesh dim strategies. all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim @@ -252,6 +288,7 @@ def expand_to_full_mesh_op_strategy( spec_list: list[Optional[DTensorSpec]] = [] for specs in zip(*strategy_comb): if specs[0] is not None: + # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback spec_list.append(DTensorSpec(mesh, specs)) else: spec_list.append(None) @@ -269,30 +306,36 @@ def expand_to_full_mesh_op_strategy( # input_spec matches the first argument's runtime sharding, otherwise we skip continue - # check inputs shardable - inputs_shardable = all( + output_specs: tuple[Optional[DTensorSpec], ...] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + + # check all inputs are shardable + if not all( is_tensor_shardable(inp.shape, s) for inp, s in zip(input_args_strategy, input_specs) - ) + ): + continue - # only add to the all_strategies list when all inputs are shardable - if inputs_shardable: - redistribute_cost = [ - generate_redistribute_costs(input_strategy, input_spec) - for input_strategy, input_spec in zip(input_args_strategy, input_specs) - ] - if input_index > 1: - output_specs = tuple(spec_list[:input_index]) - else: - if spec_list[0] is not None: - output_specs = spec_list[0] # type: ignore[assignment] - else: - raise RuntimeError("output spec is None") - strategy = OpSpec( - output_specs=output_specs, - input_specs=input_specs, - redistribute_cost=redistribute_cost, - ) - all_strategies.append(strategy) + # perform additional op-specific filtering + if is_valid_strategy_cb is not None: + if not is_valid_strategy_cb(input_specs, output_specs): + continue + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + + strategy = OpSpec( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) return OpStrategy(all_strategies) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index 4b1536644b87..32ab4943b5f0 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -8,6 +8,7 @@ import torch from torch._ops import OpOverload from torch._subclasses import FakeTensorMode +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( OpInfo, @@ -101,7 +102,49 @@ def register_op_strategy( schema_info: Optional[RuntimeSchemaInfo] = None, ): """ - Register a sharding strategy generator for an operator. + Register a :class:`OpStrategy` generator for an operator. + + During the sharding propagation, DTensor wants to enumerate all + acceptable sharding specs (:class:`OpSpec`) for an operator, + and by "acceptable" we mean that the operator can be executed on + the ``_local_tensor`` of DTensor args/kwargs (with ``OpSpec.input_specs``) + and the output(s) constitute valid DTensor(s) (with ``OpSpec.output_specs``). + + ``strategy_func`` is the function that enumerates such acceptable specs + for the operator ``op_overload``. One general approach to write ``strategy_func`` + is, if the operator has simple arguments structure (e.g. mm, bmm), first enumerating + all sharding specs for the operands, and then filtering out the ones that + are not valid. For example, for ``mm``, the operands are two 2D tensors, and + if both ``input`` and ``mat2`` have sharding placements ``[Shard(0)]``, then this + is not an acceptable ``input_specs``. + + Once we have a way to enumerate all acceptable sharding specs, we can use each + of them to construct a :class:`OpSpec`. The ``OpSpec.input_specs`` directly comes + from the sharding spec, and the ``OpSpec.output_specs`` is therefore determined + (e.g. ``[Shard(1)]`` @ ``[Shard(0)]`` yields ``[Partial()]``). In addition, + :class:`OpSpec` also contains ``redistribute_cost`` which records the redistribution + cost from each :class:`OpSpec` in the source :class:`OpStrategy.strategies` to + the target sharding spec, for each operand. + + The ``strategy_func`` should return a :class:`OpStrategy` which contains a list of + all the :class:`OpSpec`s generated in the above. + + The optional ``schema_info`` tells which non-DTensor args/kwargs could affect the + cache and whether ``pytree`` is needed to flatten the nested args. ``static_argnum`` + marks the starting index of the non-DTensor args that should be hashed into the + sharding propagation hash key, and ``static_kwargkey`` marks the keys of the + non-DTensor kwargs that should be hashed. ``needs_pytree`` should be used when + the input arg has :class:`list` or :class:`dict` structure. + + For example, ``aten.cat.default`` op has a ``List[Tensor]`` argument ``tensors`` + and an ``int`` argument ``dim``. Because ``dim`` affects the sharding propagation + result, we want to pass ``RuntimeSchemaInfo(static_argnum=1)`` because the argument + index of ``dim`` is 1. Besides, we also want to set ``needs_pytree=True`` because + ``tensors`` needs be flattened in sharding propagation. Another example is + ``aten.histc.default``. ``histc`` has 4 arguments (self, bins, min, max) and the + last two would affect sharding propagation along with the :class:`DTensor` argument + ``self``. Since the argument index of ``min`` is 2, the `schema_info` should be + `RuntimeSchemaInfo(static_argnum=2)`. """ self.op_strategy_funcs[op_overload] = strategy_func if schema_info is not None: @@ -258,8 +301,9 @@ def propagate(self, op_info: OpInfo) -> None: # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. # This is generally ok because this only happens during tracing in torch.compile, - # and tracing does not need to be as fast as eagermode DTensor usages. - if op_info.schema.has_symints: + # and compile autograd initial tracing, which do not need to be as fast as + # eagermode DTensor usages. + if _are_we_tracing(): output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) else: output_sharding = cast( @@ -353,7 +397,10 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin for _ in range(len(op_schema.op._schema.returns)) ] ) - elif op_schema.return_type_tensor(): + elif ( + op_schema.return_type_tensor() + or op_schema.return_type_list_tensor_like() + ): output_specs = output_strategy.output_specs else: output_specs = None diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index 30cc25ae89a6..a3798eac4ae0 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -27,7 +27,7 @@ class LocalShardsWrapper(torch.Tensor): """ A wrapper class to hold local shards of a DTensor. - This class is used largely for checkpointing purposes and implicity subtypes + This class is used largely for checkpointing purposes and implicitly subtypes the _Checkpointable protocol. """ @@ -159,7 +159,7 @@ def handle_view(args, kwargs) -> "LocalShardsWrapper": ] elif args[0].local_shards()[0].ndim == 1: assert args[0].storage_metadata().size[0] == view_shape[0] - # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + # This case is for optimizer sharding as regardless of sharding type, optimizer state is row wise sharded res_shards_list = [ aten.view.default(shard, shard.shape, **kwargs) for shard in args[0].local_shards() diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 92ea70eb16a8..6521eeac9b3e 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -296,7 +296,7 @@ def compute_global_tensor_shape( for shape_tensor in gathered_shaped_tensors: if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): raise RuntimeError( - "Non-sharded dimentions should have identical size across ranks." + "Non-sharded dimensions should have identical size across ranks." ) shape_tensor_list = shape_tensor.tolist() sharded_dim_sum += shape_tensor_list[shard_dim] diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py index 570161b67682..99978f9cc6b5 100644 --- a/torch/distributed/tensor/debug/_comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -395,7 +395,7 @@ def add_json_information(json_dict, fqn): json_dict: dict[str, Any] = {} add_json_information(json_dict, "Global") - # converts dictonary into json file + # converts dictionary into json file with open(file_name, "w") as json_file: json.dump(json_dict, json_file, indent=4) diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index da004aef4071..8625a3f7dd1d 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -27,11 +27,10 @@ def get_device_type() -> str: - return ( - "cuda" - if torch.cuda.is_available() and torch.cuda.device_count() >= 4 - else "cpu" - ) + device_type = "cpu" + if torch.accelerator.device_count() >= 4: + device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + return device_type c10d_functional = torch.ops.c10d_functional @@ -711,7 +710,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def run_example(world_size: int, rank: int, example_name: str) -> None: # set manual seed - # intializing class with all of the functions + # initializing class with all of the functions instantiated_example = CommDebugModeExample(world_size, rank) # dict that stores example code function names name_to_example_code: dict[str, Callable[[], None]] = { diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py index 9a3c2bbabd9e..994f2ee10f69 100644 --- a/torch/distributed/tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs """ The following example demonstrates how to train a ConvNeXt model -with intermediate activations sharded across mutliple GPUs via DTensor +with intermediate activations sharded across multiple GPUs via DTensor To run the example, use the following command: torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index f66ea658daf4..2c5d10413610 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -231,7 +231,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size): # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the - # assumption. Torchrec needs to pass in this information explicitely. + # assumption. Torchrec needs to pass in this information explicitly. # shape/stride are global tensor's shape and stride dtensor = DTensor.from_local( local_shards_wrapper, # a torch.Tensor subclass @@ -324,7 +324,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): # create a DTensor from the local shard for the current table # note: for uneven sharding, we need to specify the shape and stride because # DTensor would assume even sharding and compute shape/stride based on the - # assumption. Torchrec needs to pass in this information explicitely. + # assumption. Torchrec needs to pass in this information explicitly. dtensor = DTensor.from_local( local_shards, device_submesh, diff --git a/torch/distributed/tensor/examples/visualize_sharding_example.py b/torch/distributed/tensor/examples/visualize_sharding_example.py index 7152c928d2f2..7c0ab3adfffa 100644 --- a/torch/distributed/tensor/examples/visualize_sharding_example.py +++ b/torch/distributed/tensor/examples/visualize_sharding_example.py @@ -18,6 +18,9 @@ rank = int(os.environ["RANK"]) +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + + def section(msg: str) -> None: if rank == 0: rich.print(rich.rule.Rule(msg)) @@ -31,7 +34,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: section("[bold]1D Tensor; 1D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (4,)) +m = dist.init_device_mesh(device_type, (4,)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -43,7 +46,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 1D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (4,)) +m = dist.init_device_mesh(device_type, (4,)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate()]), @@ -59,7 +62,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]1D Tensor; 2D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (2, 2)) +m = dist.init_device_mesh(device_type, (2, 2)) t = torch.ones(4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), @@ -79,7 +82,7 @@ def visualize(t: dt.DTensor, msg: str = "") -> None: ) section("[bold]2D Tensor; 2D Mesh[/bold]") -m = dist.init_device_mesh("cuda", (2, 2)) +m = dist.init_device_mesh(device_type, (2, 2)) t = torch.ones(4, 4) visualize( dt.distribute_tensor(t, m, [dt.Replicate(), dt.Replicate()]), diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index b3a5768f6fc8..f8e0984d6514 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -43,15 +43,24 @@ class _RotateMethod(Enum): aten = torch.ops.aten logger = logging.getLogger(__name__) -_is_hip: bool = hasattr(torch.version, "hip") and torch.version.hip is not None -if _is_hip: - gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName - _is_ck_supported = False - for arch in ["gfx942", "gfx950"]: - if arch in gcn_arch_name: - _is_ck_supported = True - _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library - _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"] + +def _need_scaling() -> bool: + if hasattr(torch.version, "hip") and torch.version.hip is not None: + gcn_arch_name = torch.cuda.get_device_properties("cuda").gcnArchName + _is_ck_supported = False + for arch in ["gfx942", "gfx950"]: + if arch in gcn_arch_name: + _is_ck_supported = True + # Check the function exists + _preferred_rocm_fa_library = torch.backends.cuda.preferred_rocm_fa_library + _CK_BACKEND = torch.backends.cuda._ROCmFABackends["ck"] + # Note: it is possible that CK is selected but not compiled in the binary. + if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND: + # Unsure about CK's behavior, keep logsumexp untouched + return False + return True + else: + return False class _DispatchMode(Enum): @@ -249,7 +258,7 @@ def next_buffer(self) -> torch.Tensor: class _AllGatherRotater(_RingRotater): """ - Allgather the kv and return the only the requried kv. + Allgather the kv and return the only the required kv. Only one communication will be done. """ @@ -287,7 +296,7 @@ def _create_rotater( elif method == _RotateMethod.ALL_GATHER: return _AllGatherRotater(pg, seq_dim) else: - raise NotImplementedError(f"Unkonwn method {method}") + raise NotImplementedError(f"Unknown method {method}") def _templated_ring_attention( @@ -349,12 +358,12 @@ def _templated_ring_attention( First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the no-load-balance case. This iteration corresponds to the `if` of the - (`if, `elif`, `else`) in the implemementation. + (`if, `elif`, `else`) in the implementation. Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and (k0, k3). For rank0, no computation is needed for q0. However, computations for q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the - `else` of the (`if`, `elif`, `else`) in the implemementation. + `else` of the (`if`, `elif`, `else`) in the implementation. For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. @@ -456,14 +465,8 @@ def _templated_ring_attention( is_causal=is_causal_behavior.value, **kwargs, ) - if _is_hip: # See: https://github.com/pytorch/pytorch/issues/156012 - need_scaling = True - # Note: it is possible that CK is seleted but not compiled in the binary. - if _is_ck_supported and _preferred_rocm_fa_library() == _CK_BACKEND: - # Unsure about CK's behavior, keep logsumexp untouched - need_scaling = False - if need_scaling: - logsumexp *= 0.6931471805599453 + if _need_scaling(): + logsumexp *= 0.6931471805599453 sdpa_merger.step(out, logsumexp, partial) return *sdpa_merger.results(), *rest @@ -934,7 +937,7 @@ def _distribute_function( the inputs and outputs of a function. Similar to ``distribute_module``, this API installs hooks to the ``fn`` to convert the inputs and outputs. There are two major differences between ``distribute_function`` and ``distribute_module``. - First, a function does not have parammeters and buffers, as a result, + First, a function does not have parameters and buffers, as a result, ``distribute_function`` itself won't convert any parameters/buffers but simply install the input and output hooks. The tensor conversion will happen in the hooks. Another difference is an nn.Module subclass can have several instances and each @@ -950,9 +953,9 @@ def _distribute_function( ``fn_module`` is ``torch.nn.functional``. device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the input and output hooks to distribute the tensors. - input_fn (Optioinal[Callable]): the hook to distribute or convert the input + input_fn (Optional[Callable]): the hook to distribute or convert the input arguments of ``fn``. - output_fn (Optioinal[Callable]): the hook to distribute or convert the output + output_fn (Optional[Callable]): the hook to distribute or convert the output arguments of ``fn``. """ @@ -1007,7 +1010,7 @@ class _AttentionContextParallel(ParallelStyle): Applies context parallel optimizations to the attention layer. This will work for nn.MultiHeadedAttention and custom attention layers that - call F.scaled_dotproduct_attention with a simliar signature. + call F.scaled_dotproduct_attention with a similar signature. This expects the `forward` method consumes either: diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index 7eb2e72343e2..fd91328c0b37 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -112,7 +112,7 @@ def local_map( >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh >>> - >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion + >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion >>> local_mm_allreduce_forward = local_map( >>> mm_allreduce_forward, >>> out_placements=[Replicate()], diff --git a/torch/distributed/tensor/experimental/_register_sharding.py b/torch/distributed/tensor/experimental/_register_sharding.py index f91fae4580bc..b286b151efed 100644 --- a/torch/distributed/tensor/experimental/_register_sharding.py +++ b/torch/distributed/tensor/experimental/_register_sharding.py @@ -41,7 +41,7 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]): as the original op (except that if an arg is a :class:`torch.Tensor`, it will be replaced by a tensor-like object that DTensor uses internally). The function should return a sequence of 2-tuples, each specifying acceptable output placements and its - corresponding intput placements. + corresponding input placements. Example: >>> # xdoctest: +SKIP("distributed") diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py index 6513123e2462..c41da260a02f 100644 --- a/torch/distributed/tensor/parallel/_data_parallel_utils.py +++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -30,7 +30,7 @@ def _flatten_tensor( @no_type_check def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None): - # unflatten would mainly be called everytime FSDP allgather parameters. + # unflatten would mainly be called every time FSDP allgather parameters. result = DTensor.from_local( tensor, spec.mesh, diff --git a/torch/distributed/tensor/parallel/_utils.py b/torch/distributed/tensor/parallel/_utils.py deleted file mode 100644 index 0a78872f57d8..000000000000 --- a/torch/distributed/tensor/parallel/_utils.py +++ /dev/null @@ -1,67 +0,0 @@ -# mypy: allow-untyped-defs -import warnings -from typing import Union - -from torch.distributed.device_mesh import _mesh_resources -from torch.distributed.tensor import DeviceMesh -from torch.distributed.tensor.placement_types import Placement - - -try: - from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling -except Exception: - - def is_torchdynamo_compiling(): # type: ignore[misc] - return False - - -LayoutsType = Union[Placement, tuple[Placement, ...]] - - -def _deprecate_warnings(func_name: str, extra_msg: str) -> None: - """ - Inject common validation logics for `_prepare_input` funcs via this decorator. - - Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor` - and only 1D :class:`DeviceMesh` is passed in. - """ - # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. - if not is_torchdynamo_compiling(): - warnings.warn( - f"{func_name} is deprecated and will be removed soon. {extra_msg}", - FutureWarning, - stacklevel=3, - ) - - -def _validate_tp_mesh_dim( - device_mesh: DeviceMesh, -) -> None: - """ - Check whether TP mesh dimension is valid or not. - - Args: - device_mesh (:class:`DeviceMesh`): - The `device_mesh` where we perform - Tensor Parallelism on. - - Return: - `True` if the mesh dimension - is valid, `False` otherwise. - """ - if device_mesh.ndim > 1: - raise ValueError( - f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" - 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' - ) - - root_mesh = _mesh_resources.get_root_mesh(device_mesh) - # if a root mesh is not the same as device_mesh, - # meaning the device_mesh is sliced out from the root mesh. - if root_mesh and root_mesh != device_mesh: - tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh) - if tp_mesh_dim_in_root != root_mesh.ndim - 1: - raise RuntimeError( - f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.", - "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.", - ) diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index 81c005000a85..2a3369a8edda 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn from torch.distributed.device_mesh import _mesh_resources, DeviceMesh -from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim from torch.distributed.tensor.parallel.style import ParallelStyle @@ -71,7 +70,6 @@ def parallelize_module( # type: ignore[return] torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") device_mesh = device_mesh or _mesh_resources.get_current_mesh() - _validate_tp_mesh_dim(device_mesh) if parallelize_plan is None: warnings.warn( diff --git a/torch/distributed/tensor/parallel/ddp.py b/torch/distributed/tensor/parallel/ddp.py index 39ab299b4f79..7b19f9767519 100644 --- a/torch/distributed/tensor/parallel/ddp.py +++ b/torch/distributed/tensor/parallel/ddp.py @@ -36,7 +36,7 @@ def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]): def _reconstruct_dtensor(module: nn.Module, _input: Any): """ - Recontruct DTensor parameters from local tensors + Reconstruct DTensor parameters from local tensors """ param_list = [] # TODO: To add perf optimizations to this iterations diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 5282542950c4..1b0b8cac7c76 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -326,7 +326,7 @@ def __init__(self, device_handle) -> None: super().__init__() self.compute_stream = None self.device_handle = device_handle - # we have to use the dynamo disable this way to disable dynamo as the decorater way would + # we have to use the dynamo disable this way to disable dynamo as the decorator way would # trigger build failure with torch deploy... self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] self.post_unflatten_transform diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index a8fdd7bec1ac..b37d49bd3074 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -701,7 +701,7 @@ def _partition_value( # _partition_value: partition the value of a replicated tensor on the mesh dimension # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a divison operation + # - i.e. _partition_value on a sum reduce op is just a division operation # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation # TODO: if the reduce_op is min/max, etc. the _partition_value should be a # different operation diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 6c3c2b6f9377..3ed8a6c37883 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -1,58 +1,38 @@ -import builtins -import copy -import dataclasses -import inspect import os -import sys -import typing import warnings import zipfile -from collections.abc import Iterator -from enum import auto, Enum -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from collections.abc import Mapping +from typing import Any, Callable, Optional, Union +from typing_extensions import deprecated import torch import torch.utils._pytree as pytree -from torch.fx._compatibility import compatibility from torch.fx.passes.infra.pass_base import PassResult -from torch.fx.passes.infra.pass_manager import PassManager from torch.types import FileLike -from torch.utils._pytree import ( - FlattenFunc, - FromDumpableContextFn, - ToDumpableContextFn, - UnflattenFunc, -) - - -if TYPE_CHECKING: - # Import the following modules during type checking to enable code intelligence features, - # Do not import unconditionally, as they import sympy and importing sympy is very slow - from torch._ops import OpOverload - from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint __all__ = [ + "AdditionalInputs", "Constraint", - "Dim", - "ExportBackwardSignature", - "ExportGraphSignature", - "ExportedProgram", "CustomDecompTable", - "ModuleCallEntry", - "ModuleCallSignature", "default_decompositions", + "Dim", "dims", - "export", + "draft_export", "export_for_training", + "export", + "ExportBackwardSignature", + "ExportedProgram", + "ExportGraphSignature", + "FlatArgsAdapter", "load", + "ModuleCallEntry", + "ModuleCallSignature", "register_dataclass", "save", + "ShapesCollection", "unflatten", - "FlatArgsAdapter", "UnflattenedModule", - "AdditionalInputs", - "draft_export", ] # To make sure export specific custom ops are loaded @@ -73,12 +53,17 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] +@deprecated( + "`torch.export.export_for_training` is deprecated and will be removed in PyTorch 2.10. " + "Please use `torch.export.export` instead, which is functionally equivalent.", + category=FutureWarning, +) def export_for_training( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), ) -> ExportedProgram: @@ -175,9 +160,9 @@ def export_for_training( def export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, strict: bool = False, preserve_module_call_signature: tuple[str, ...] = (), ) -> ExportedProgram: @@ -534,9 +519,9 @@ def load( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, - dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), strict: bool = False, ) -> ExportedProgram: diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 9a9ed922c83e..9d2179fcf252 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -4,13 +4,13 @@ import os import re import tempfile +from collections.abc import Mapping from dataclasses import dataclass from enum import IntEnum from typing import Any, Callable, Optional, Union import torch import torch._logging._internal -import torch._logging.structured import torch.utils._pytree as pytree from torch._export.passes.insert_custom_op_guards import ( get_op_profiles, @@ -362,7 +362,7 @@ def _log_expression_created( def draft_export( mod: torch.nn.Module, args: tuple[Any, ...], - kwargs: Optional[dict[str, Any]] = None, + kwargs: Optional[Mapping[str, Any]] = None, *, dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, preserve_module_call_signature: tuple[str, ...] = (), diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 09163e3bffa8..35be163b7e94 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -49,15 +49,13 @@ ) from torch._export.verifier import SpecViolationError from torch._export.wrappers import _wrap_submodules +from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call from torch._functorch._aot_autograd.input_output_analysis import ( _graph_input_names, _graph_output_names, ) from torch._functorch._aot_autograd.schemas import GraphSignature from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container -from torch._functorch._aot_autograd.traced_function_transforms import ( - create_functional_call, -) from torch._functorch._aot_autograd.utils import ( create_tree_flattened_fn, register_buffer_assignment_hook, @@ -1725,6 +1723,7 @@ def _is_impure(node): gm.graph.eliminate_dead_code(_is_impure) # create graph signature + assert out_spec.spec is not None, "out_spec.spec is None!" input_names = _graph_input_names(gm) output_names = _graph_output_names(gm) sig = GraphSignature( @@ -1739,7 +1738,7 @@ def _is_impure(node): buffers_to_mutate={}, user_inputs_to_mutate={}, in_spec=in_spec, - out_spec=out_spec, # type: ignore[arg-type] + out_spec=out_spec.spec, backward_signature=None, input_tokens=[], output_tokens=[], diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 553d2eb2bf3b..f7ae6cbf21ac 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -138,12 +138,7 @@ def _insert_copy_for_mutations( Find the all the buffers and inputs that were mutated and insert copy_ operators to reflect mutations. """ - output_node = None - for node in gm.graph.nodes: - if node.op == "output": - output_node = node - break - assert output_node is not None + output_node = gm.graph.output_node() outputs = pytree.tree_flatten(output_node.args)[0] assert len(outputs) == len(mutated_outputs) @@ -169,13 +164,13 @@ def _insert_copy_for_mutations( ) return_nodes_to_copy[return_node] = copy_node - output_args = [ + output_args = tuple( return_nodes_to_copy[node] if node in return_nodes_to_copy else node for node in user_output_nodes - ] + ) with gm.graph.inserting_before(output_node): # Only return user outputs - new_output = gm.graph.output(tuple(output_args)) + new_output = gm.graph.output(output_args) output_node.replace_all_uses_with(new_output) gm.graph.erase_node(output_node) new_output.name = output_node.name @@ -199,19 +194,18 @@ def _get_codegen( """ if forward_arg_names: names = forward_arg_names + elif ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) else: - if ( - in_spec.type == tuple - and in_spec.num_children == 2 - and in_spec.children_specs[0].type == tuple - and in_spec.children_specs[1].type == dict - ): - # if in_spec contains the args (tuple) and kwargs (dict) - names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] - # add kwarg names - names.extend(in_spec.children_specs[1].context) - else: - names = [f"arg_{i}" for i in range(in_spec.num_children)] + names = [f"arg_{i}" for i in range(in_spec.num_children)] return _PyTreeCodeGen( _PyTreeInfo( @@ -228,8 +222,6 @@ def _unlift( mutated_outputs: Sequence[Optional[str]], in_spec: pytree.TreeSpec, out_spec: Optional[pytree.TreeSpec], - state_dict: dict[str, Any], - constants: dict[str, Any], forward_arg_names: Optional[list[str]] = None, ): """ @@ -427,7 +419,7 @@ def _create_stateful_graph_module( return stateful_gm -def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule: # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) @@ -482,14 +474,13 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu ) ] + assert ep.call_spec.in_spec is not None new_gm = _unlift( new_gm, lifted_inputs, mutated_outputs, ep.call_spec.in_spec, ep.call_spec.out_spec, - ep.state_dict, - ep.constants, forward_arg_names=forward_arg_names, ) unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index f951b5818afd..ccc3660f7600 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -85,15 +85,19 @@ def __call__(self, min=None, max=None) -> "_DimHint": class Dim: """ - The `Dim` class allows users to specify dynamism in their exported programs. By marking a dimension with a `Dim`, - the compiler associates the dimension with a symbolic integer containing a dynamic range. + The ``Dim`` class allows users to specify dynamism in their exported + programs. By marking a dimension with a ``Dim``, the compiler associates the + dimension with a symbolic integer containing a dynamic range. - The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: `Dim.AUTO`, `Dim.DYNAMIC`, `Dim.STATIC`), - or named Dims (i.e. `Dim("name", min=1, max=2)`). + The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: + ``Dim.AUTO``, ``Dim.DYNAMIC``, ``Dim.STATIC``), or named Dims (i.e. + ``Dim("name", min=1, max=2)``). - Dim hints provide the lowest barrier to exportability, with the user only needing to specify if a dimension - if dynamic, static, or left for the compiler to decide (`Dim.AUTO`). The export process will automatically - infer the remaining constraints on min/max ranges and relationships between dimensions. + Dim hints provide the lowest barrier to exportability, with the user only + needing to specify if a dimension if dynamic, static, or left for the + compiler to decide (``Dim.AUTO``). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between + dimensions. Example:: @@ -112,19 +116,19 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) - Here, export would raise an exception if we replaced all uses of `Dim.AUTO` with `Dim.DYNAMIC`, - as x.shape[0] is constrained to be static by the model. + Here, export would raise an exception if we replaced all uses of ``Dim.AUTO`` with ``Dim.DYNAMIC``, + as ``x.shape[0]`` is constrained to be static by the model. More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, - e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints. + e.g. ``(x.shape[0] + y.shape[1]) % 4 == 0``, to be raised if runtime inputs do not satisfy such constraints. - You may also specify min-max bounds for Dim hints, e.g. `Dim.AUTO(min=16, max=32)`, `Dim.DYNAMIC(max=64)`, + You may also specify min-max bounds for Dim hints, e.g. ``Dim.AUTO(min=16, max=32)``, ``Dim.DYNAMIC(max=64)``, with the compiler inferring the remaining constraints within the ranges. An exception will be raised if the valid range is entirely outside the user-specified range. Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler infers constraints that do not match the user specification. For example, exporting the previous - model, the user would need the following `dynamic_shapes` argument:: + model, the user would need the following ``dynamic_shapes`` argument:: s0 = Dim("s0") s1 = Dim("s1", min=16) @@ -134,8 +138,9 @@ def forward(self, x, y): } ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) - Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. - For example, the following indicates one dimension is a multiple of another plus 4:: + Named Dims also allow specification of relationships between dimensions, up + to univariate linear relations. For example, the following indicates one + dimension is a multiple of another plus 4:: s0 = Dim("s0") s1 = 3 * s0 + 4 diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 7f2303f66394..372eb3a29533 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -1,16 +1,27 @@ import copy import dataclasses import functools +import os import types import typing import typing_extensions +import zipfile +from pathlib import Path import torch +from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file from torch.export.exported_program import _decompose_exported_program +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +__all__ = [] # type: ignore[var-annotated] + + def _copy_graph_module_and_signature( - ep: torch.fx.GraphModule, + ep: torch.export.ExportedProgram, ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), # and this can break placeholder names in some particular cases. @@ -28,7 +39,7 @@ def _copy_graph_module_and_signature( for old_node, new_node in zip(old_phs, new_phs): new_node.name = old_node.name - return gm, new_graph_signature # type: ignore[return-value] + return gm, new_graph_signature def _remove_detach_pass( @@ -73,18 +84,27 @@ def _export_forward_backward( return ep._update(gm, new_graph_signature) -@typing.no_type_check -def _sticky_export(forward_func, dynamic_shapes_callback=None): +def _sticky_export( + forward_func: typing.Callable[_InputT, _RetT], + dynamic_shapes_callback: typing.Optional[ + typing.Callable[ + _InputT, + typing.Union[ + list[typing.Any], dict[str, typing.Any], tuple[typing.Any, ...] + ], + ] + ] = None, +) -> typing.Callable[_InputT, _RetT]: """ Lazily export the model on first forward call. Usage: model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) """ - model = forward_func.__self__ - original_forward = forward_func.__func__ + model = forward_func.__self__ # type: ignore[attr-defined] + original_forward = forward_func.__func__ # type: ignore[attr-defined] @functools.wraps(forward_func) - def wrapper(*args, **kwargs): + def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: # Unpatch forward to avoid recursion during export model.forward = types.MethodType(original_forward, model) @@ -99,7 +119,7 @@ def wrapper(*args, **kwargs): kwargs, dynamic_shapes=dynamic_shapes_spec, ).module() - wrapper._exported_artifact = exported + wrapper._exported_artifact = exported # type: ignore[attr-defined] finally: # Restore the wrapper after export model.forward = wrapper @@ -115,10 +135,6 @@ class _ExportMethod: fallbacks: list[torch.export.ExportedProgram] -_InputT = typing_extensions.ParamSpec("_InputT") -_RetT = typing.TypeVar("_RetT") - - class _ExportPackage: """ An export package is a collection of torch.export()-ed PyTorch models consisting of @@ -308,7 +324,8 @@ def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] if isinstance(fn, torch.nn.Module): _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 - fn, lambda _: _exporter_context + fn, + lambda _: _exporter_context, # type: ignore[arg-type] ) def _define_overload( @@ -333,18 +350,79 @@ def _method_overloads( for overload, ep in method_data.overloads.items(): yield f"{method}:{overload}", ep - def _compiled_and_package(self, f: torch.types.FileLike) -> None: - options = { + def _compiled_and_package( + self, + f: torch.types.FileLike, + standalone: bool = False, + package_example_inputs: bool = False, + ) -> None: + options: dict[str, typing.Any] = { "aot_inductor.package": True, "aot_inductor.package_cpp_only": True, "always_keep_tensor_constants": True, "aot_inductor.package_constants_in_so": False, + "aot_inductor.compile_standalone": standalone, } - weights_map = {} + aoti_files_map = {} + model_names = [] for name, ep in self._method_overloads: - weights = torch._inductor.aot_compile(ep.module(), (), options=options) # type: ignore[arg-type] - weights_map[name] = weights - torch._inductor.package.package.package_aoti( + name = name.replace(":", "__") + model_names.append(name) + options["aot_inductor.model_name_for_generated_files"] = name + aoti_files = torch._inductor.aot_compile( + ep.module(), # type: ignore[arg-type] + ep.example_inputs[0], + kwargs=ep.example_inputs[1], + options=options, + ) + aoti_files_map[name] = aoti_files + + from torch._inductor.package import package + + pt2_path = package.package_aoti( f, - weights_map, # type: ignore[arg-type] + aoti_files_map, # type: ignore[arg-type] + ) + + if not standalone: + return + + assert isinstance(pt2_path, str) + base_directory = os.path.dirname(pt2_path) + package_name = os.path.basename(pt2_path)[:-4] + with ( + zipfile.ZipFile(pt2_path, "r") as zip_ref, + ): + zip_ref.extractall(base_directory) + + example_inputs_map: typing.Optional[dict[str, int]] = ( + {} if package_example_inputs else None + ) + use_cuda = False + for name, ep in self._method_overloads: + name = name.replace(":", "__") + # TODO: also dump kwargs + # TODO: currently only support list of Tensors and they need to be on the same device + if not ep.example_inputs: + continue + for inp in ep.example_inputs[0]: + if isinstance(inp, torch.Tensor) and inp.device.type == "cuda": + # TODO: more carefully determine the device type + use_cuda = True + if package_example_inputs: + assert example_inputs_map is not None + example_inputs_map[name] = len(ep.example_inputs[0]) + for i, t in enumerate(ep.example_inputs[0]): + path = Path(base_directory) / f"{name}_input_{i}.pt" + torch.save(t, path) + + cmake_file_str = _get_make_file(package_name, model_names, use_cuda) + + with open(Path(base_directory) / "CMakeLists.txt", "w") as file: + file.write(cmake_file_str) + + main_file_str = _get_main_cpp_file( + package_name, model_names, use_cuda, example_inputs_map ) + with open(Path(base_directory) / "main.cpp", "w") as file: + file.write(main_file_str) diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py new file mode 100644 index 000000000000..b91dfbb0db80 --- /dev/null +++ b/torch/export/experimental/_utils.py @@ -0,0 +1,206 @@ +import typing + +from torch._inductor.utils import IndentedBuffer + + +__all__ = [] # type: ignore[var-annotated] + + +def _get_main_cpp_file( + package_name: str, + model_names: list[str], + cuda: bool, + example_inputs_map: typing.Optional[dict[str, int]], +) -> str: + """ + Generates a main.cpp file for AOTInductor standalone models in the specified package. + + Args: + package_name (str): Name of the package containing the models. + model_names (List[str]): List of model names to include in the generated main.cpp. + cuda (bool): Whether to generate code with CUDA support. + example_inputs_map (Optional[Dict[str, List[Tensor]]]): A mapping from model name to + its list of example input tensors. If provided, the generated main.cpp will + load and run these inputs. + + Returns: + str: The contents of the generated main.cpp file as a string. + """ + + ib = IndentedBuffer() + + ib.writelines( + [ + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + "#include ", + ] + ) + if cuda: + ib.writelines( + [ + "#include ", + "#include ", + ] + ) + + for model_name in model_names: + ib.writeline( + f'#include "{package_name}/data/aotinductor/{model_name}/{model_name}.h"' + ) + + ib.newline() + for model_name in model_names: + ib.writeline(f"using torch::aot_inductor::AOTInductorModel{model_name};") + + ib.writelines( + [ + "using torch::aot_inductor::ConstantHandle;", + "using torch::aot_inductor::ConstantMap;", + "", + "int main(int argc, char* argv[]) {", + ] + ) + + with ib.indent(): + ib.writeline(f'std::string device_str = "{"cuda" if cuda else "cpu"}";') + ib.writeline("try {") + + with ib.indent(): + ib.writeline("c10::Device device(device_str);") + + if example_inputs_map is not None: + # TODO: add device + for i, model_name in enumerate(model_names): + num_inputs = example_inputs_map[model_name] + + ib.writeline(f"// Load input tensors for model {model_name}") + ib.writeline(f"std::vector input_tensors{i + 1};") + ib.writeline(f"for (int j = 0; j < {num_inputs}; ++j) {{") + with ib.indent(): + ib.writeline( + f'std::string filename = "{model_name}_input_" + std::to_string(j) + ".pt";' + ) + ib.writeline("std::ifstream in(filename, std::ios::binary);") + ib.writeline("if (!in.is_open()) {") + with ib.indent(): + ib.writeline( + 'std::cerr << "Failed to open file: " << filename << std::endl;' + ) + ib.writeline("return 1;") + ib.writeline("}") + ib.writeline( + "std::vector buffer((std::istreambuf_iterator(in)), std::istreambuf_iterator());" + ) + ib.writeline( + "torch::IValue ivalue = torch::pickle_load(buffer);" + ) + ib.writeline( + f"input_tensors{i + 1}.push_back(ivalue.toTensor().to(device));" + ) + ib.writeline("}") + ib.newline() + + ib.newline() + ib.writeline("\n// Create array of input handles") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto input_handles{i + 1} =", + f" torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(input_tensors{i + 1});", + ] + ) + + ib.writeline("\n// Create array for output handles") + for i in range(len(model_names)): + ib.writeline(f"AtenTensorHandle output_handle{i + 1};") + + ib.writeline("\n// Create and load models") + for i, model_name in enumerate(model_names): + ib.writelines( + [ + f"auto constants_map{i + 1} = std::make_shared();", + f"auto constants_array{i + 1} = std::make_shared>();", + f"auto model{i + 1} = AOTInductorModel{model_name}::Create(", + f" constants_map{i + 1}, constants_array{i + 1}, device_str,", + f' "{package_name}/data/aotinductor/{model_name}/");', + f"model{i + 1}->load_constants();", + ] + ) + + if example_inputs_map is not None: + ib.writeline("\n// Run the models") + for i in range(len(model_names)): + ib.writeline( + f"torch::aot_inductor::DeviceStreamType stream{i + 1} = nullptr;" + ) + ib.writeline( + f"model{i + 1}->run(&input_handles{i + 1}[0], &output_handle{i + 1}, stream{i + 1}, nullptr);" + ) + + ib.writeline("\n// Convert output handles to tensors") + for i in range(len(model_names)): + ib.writelines( + [ + f"auto output_tensor{i + 1} =", + f" torch::aot_inductor::alloc_tensors_by_stealing_from_handles(&output_handle{i + 1}, 1);", + ] + ) + + ib.writeline("\n// Validate outputs") + for i in range(len(model_names)): + ib.writeline( + f"""std::cout << "output_tensor{i + 1}" << output_tensor{i + 1} << std::endl;""" + ) + + ib.writeline("return 0;") + + ib.writelines( + [ + "} catch (const std::exception &e) {", + ] + ) + with ib.indent(): + ib.writeline('std::cerr << "Error: " << e.what() << std::endl;') + ib.writeline("return 1;") + + ib.writeline("}") + ib.writeline("}") + + return ib.getvalue() + + +def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str: + ib = IndentedBuffer() + + ib.writelines( + [ + "cmake_minimum_required(VERSION 3.10)", + "project(TestProject)", + "", + "set(CMAKE_CXX_STANDARD 17)", + "", + "find_package(Torch REQUIRED)", + ] + ) + if cuda: + ib.writeline("find_package(CUDA REQUIRED)") + + ib.newline() + for model_name in model_names: + ib.writeline(f"add_subdirectory({package_name}/data/aotinductor/{model_name}/)") + + ib.writeline("\nadd_executable(main main.cpp)") + if cuda: + ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") + + model_libs = " ".join(model_names) + ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") + if cuda: + ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") + + return ib.getvalue() diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index bbfb9202c560..4aee86b099e1 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -7,10 +7,10 @@ import operator import types import warnings -from collections import defaultdict, namedtuple +from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union from torch._guards import tracing, TracingContext from torch._higher_order_ops.utils import autograd_not_implemented @@ -40,6 +40,7 @@ import torch import torch.utils._pytree as pytree from torch._export.utils import ( + _build_cache, _collect_all_valid_cia_ops, _collect_and_set_constant_attrs, _collect_param_buffer_metadata, @@ -325,7 +326,7 @@ def default_decompositions() -> "CustomDecompTable": def _decompose_and_get_gm_with_new_signature_constants( - ep, + ep: "ExportedProgram", *, cia_to_decomp: dict[torch._ops.OperatorBase, Callable], python_decomp_table: dict[torch._ops.OperatorBase, Callable], @@ -384,9 +385,11 @@ def _is_joint_ir_decomp(ep, joint_loss_index): # Fix the graph output signature to be tuple if scalar out_spec = mod._out_spec + assert isinstance(mod.graph._codegen, _PyTreeCodeGen) orig_arg_names = mod.graph._codegen.pytree_info.orig_args # aot_export expect the return type to always be a tuple. + assert out_spec is not None if out_spec.type not in (list, tuple): out_spec = pytree.TreeSpec(tuple, None, [out_spec]) @@ -610,7 +613,7 @@ def update_arg(old_arg, new_ph): raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] - new_outputs = list(gm.graph.nodes)[-1].args[0] + new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type] # rename the placeholders assert len(new_placeholders) == len(old_placeholders) @@ -618,11 +621,18 @@ def update_arg(old_arg, new_ph): new_ph.name = new_ph.target = old_ph.name # handle name collisions with newly decomposed graph nodes - name_map = {ph.name: ph.name for ph in new_placeholders} + name_map = {} + find_available: dict[str, int] = defaultdict(int) + used_names: set[str] = set() + for ph in new_placeholders: + name_map[ph.name] = ph.name + _build_cache(ph.name, find_available, used_names) for node in gm.graph.nodes: if node.op == "placeholder": continue - node.name = _rename_without_collisions(name_map, node.name, node.name) + node.name = _rename_without_collisions( + name_map, find_available, used_names, node.name, node.name + ) # propagate names to higher order op subgraphs _name_hoo_subgraph_placeholders(gm) @@ -654,9 +664,9 @@ def update_arg(old_arg, new_ph): # update output specs gm.recompile() - for i, name in enumerate(_graph_output_names(gm)): - if isinstance(new_outputs[i], torch.fx.Node): - new_outputs[i].name = name + for output, name in zip(new_outputs, _graph_output_names(gm)): + if name is not None: + output.name = name # To match the output target with correct input for input mutations # need to find the old to new placeholder map @@ -727,7 +737,7 @@ def update_arg(old_arg, new_ph): for i, spec in enumerate(ep.graph_signature.input_specs) if isinstance(spec.arg, TensorArgument) } - for i, node in enumerate(new_outputs[len(output_specs) :]): + for node in new_outputs[len(output_specs) :]: source = gradients[node.name] spec = specs[source] # type: ignore[index] if spec.kind == InputKind.PARAMETER: @@ -1208,7 +1218,9 @@ def example_inputs(self, value): @property @compatibility(is_backward_compatible=False) def call_spec(self): - CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) + class CallSpec(NamedTuple): + in_spec: Optional[pytree.TreeSpec] + out_spec: Optional[pytree.TreeSpec] if len(self.module_call_graph) == 0: return CallSpec(in_spec=None, out_spec=None) @@ -1364,7 +1376,7 @@ def __str__(self) -> str: ) return string - def module(self) -> torch.nn.Module: + def module(self) -> torch.fx.GraphModule: """ Returns a self contained GraphModule with all the parameters/buffers inlined. """ diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 7c97e6abe171..9d3be9758a7c 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -12,8 +12,8 @@ import torch import torch.utils._pytree as pytree from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs -from torch.export.exported_program import ExportedProgram from torch.export.pt2_archive._package_weights import ( get_complete, group_weights, @@ -308,7 +308,7 @@ def _package_exported_programs( return if isinstance(exported_programs, ExportedProgram): - exported_programs = {"model", exported_programs} # type: ignore[assignment] + exported_programs = {"model": exported_programs} assert isinstance(exported_programs, dict) @@ -350,22 +350,21 @@ def package_pt2( opset_version: Optional[dict[str, int]] = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, ) -> FileLike: - """ - Saves the artifacts to a PT2Archive format - (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). - The artifact can then be loaded using ``load_pt2``. + r""" + Saves the artifacts to a PT2Archive format. The artifact can then be loaded + using ``load_pt2``. Args: - f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to implement write and flush) or a string containing a file name. exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): The exported program to save, or a dictionary mapping model name to an exported program to save. The exported program will be saved under - models/*.json. If only one ExportedProgram is specified, this will + models/\*.json. If only one ExportedProgram is specified, this will automatically be named "model". - aoti_files (Union[list[str], dict[str, list[str]]): A list of files + aoti_files (Union[list[str], dict[str, list[str]]]): A list of files generated by AOTInductor via ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, or a dictionary mapping model name to its AOTInductor generated files. @@ -563,6 +562,8 @@ def load_pt2( A ``PT2ArchiveContents`` object which contains all the objects in the PT2. """ + from torch._inductor.cpp_builder import normalize_path_separator + if not ( (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) @@ -601,6 +602,9 @@ def load_pt2( file_end = file[ len(AOTINDUCTOR_DIR) : ] # remove data/aotinductor/ prefix + file_end = normalize_path_separator( + file_end + ) # Win32 need normalize path before split. model_name = file_end.split("/")[ 0 ] # split "model_name/...cpp" into "model_name" diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 1a3cf9610a6d..3a741778d0d4 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -15,10 +15,10 @@ import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree from torch._library.fake_class_registry import FakeScriptObject +from torch.export import ExportedProgram from torch.export._tree_utils import reorder_kwargs from torch.export.exported_program import ( ConstantArgument, - ExportedProgram, ExportGraphSignature, InputKind, ModuleCallSignature, diff --git a/torch/fx/README.md b/torch/fx/README.md index 4c799da7bc40..3d42cb9375d4 100644 --- a/torch/fx/README.md +++ b/torch/fx/README.md @@ -70,7 +70,7 @@ Here, we set up a simple Module that exercises different language features: fetc The `fx.Graph` is a core data structure in FX that represents the operations and their dependencies in a structured format. It consists of a List of `fx.Node` representing individual operations and their inputs and outputs. The Graph enables simple manipulation and analysis of the model structure, which is essential for implementing various transformations and optimizations. ## Node -An `fx.Node` is a datastructure that represent individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. +An `fx.Node` is a data structure that represents individual operations within an `fx.Graph`, it maps to callsites such as operators, methods and modules. Each `fx.Node` keeps track of its inputs, the previous and next nodes, the stacktrace so you can map back the node to a line of code in your python file and some optional metadata stored in a `meta` dict. ## [GraphModule](https://pytorch.org/docs/main/fx.html#torch.fx.GraphModule) ## The `fx.GraphModule` is a subclass of `nn.Module` that holds the transformed Graph, the original module's parameter attributes and its source code. It serves as the primary output of FX transformations and can be used like any other `nn.Module`. `fx.GraphModule` allows for the execution of the transformed model, as it generates a valid forward method based on the Graph's structure. @@ -115,11 +115,11 @@ Tracing captures an intermediate representation (IR), which is represented as a Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include Nodes that specify function inputs and outputs). Each Node has a function specified by its `op` property. The Node semantics for each value of `op` are as follows: -- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is don't-care. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. -- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are don't-care +- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on. `target` is similarly the name of the argument. `args` holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. `kwargs` is ignored. Placeholders correspond to the function parameters (e.g. `x`) in the graph printout. +- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy. `args` and `kwargs` are ignored - `call_function` applies a free function to some values. `name` is similarly the name of the value to assign to. `target` is the function to be applied. `args` and `kwargs` represent the arguments to the function, following the Python calling convention - `call_module` applies a module in the module hierarchy's `forward()` method to given arguments. `name` is as previous. `target` is the fully-qualified name of the module in the module hierarchy to call. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument*. -- `call_method` calls a method on a value. `name` is as similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* +- `call_method` calls a method on a value. `name` is similar. `target` is the string name of the method to apply to the `self` argument. `args` and `kwargs` represent the arguments to invoke the module on, *including the self argument* - `output` contains the output of the traced function in its `args[0]` attribute. This corresponds to the "return" statement in the Graph printout. To facilitate easier analysis of data dependencies, Nodes have read-only properties `input_nodes` and `users`, which specify which Nodes in the Graph are used by this Node and which Nodes use this Node, respectively. Although Nodes are represented as a doubly-linked list, the use-def relationships form an acyclic graph and can be traversed as such. diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 6777b1f31cef..a578723ea1cb 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -2343,7 +2343,7 @@ def make_fx( record_module_stack, _allow_fake_constant, _error_on_data_dependent_ops, - record_stack_traces=record_stack_traces or config.trace.enabled, + record_stack_traces=record_stack_traces or config.trace.provenance_tracking, ) @functools.wraps(f) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index e38e5f777d66..e2e91624db95 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -6263,7 +6263,7 @@ def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: return self._resimplify_floor_div_axioms = False new_items = {} - for k, v in axioms.items(): + for k, v in list(axioms.items()): # A FloorDiv in implications could have became CleanDiv at this point, due to new facts # to the shapeEnv. This handles such issue but its not ideal. This is the only expression # simplification that depends on the global state of shape env. diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index e2d2f9d7466d..4e1ab646593a 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -5,6 +5,7 @@ import torch import torch.fx.traceback as fx_traceback +from torch._logging import trace_structured from torch.hub import tqdm from . import config @@ -175,13 +176,26 @@ def run( if self.extra_traceback: msg = f"While executing {node.format_node()}" msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) + msg += f"\nOriginal traceback:\n{node.stack_trace}" if ( isinstance(self.module, GraphModule) and self.module.graph is not None and isinstance(self.module.graph, torch.fx.Graph) ): - msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n" - msg += f"\nOriginal traceback:\n{node.stack_trace}" + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_interpreter_error", + "encoding": "string", + }, + payload_fn=lambda: ( + f"{msg}\nGraphModule: " + f"{self.module.print_readable(print_output=False, include_stride=True)}" # type: ignore[operator] + ), + ) + + msg += "\nUse tlparse to see full graph. " + msg += "(https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)" e.args = (msg,) + e.args[1:] if isinstance(e, KeyError): raise RuntimeError(*e.args) from e diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index 8e59fc7ae179..17929bb63787 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -31,18 +31,20 @@ def __init__( """ log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified """ - from torch._inductor.config import trace + from torch._inductor import config as inductor_config self.gm = gm self.passname = passname self.subsystem = subsystem if log_url is None: - log_url = trace.log_url_for_graph_xform + log_url = inductor_config.trace.log_url_for_graph_xform self.log_url = log_url - self.active = trace.enabled or self.log_url is not None + self.active = ( + self.log_url is not None or inductor_config.trace.provenance_tracking + ) if self.active: self.erased_nodes: set[str] = set() diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 3ec156005a01..648a80b87b68 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -51,6 +51,8 @@ def __init__(self, name: str, target: str, graph_id: int): action: list["NodeSourceAction"] from_node: list["NodeSource"] node_info: Optional["NodeInfo"] + _dict: Optional[dict[str, Any]] + _action_string: Optional[str] def __init__( self, @@ -80,6 +82,10 @@ def __init__( self.node_info = None self.from_node = [] + # cache the action string and dict representation for performance. + self._action_string: Optional[str] = None + self._dict: Optional[dict[str, Any]] = None + @property def name(self) -> str: return self.node_info.name if self.node_info else "" @@ -96,7 +102,9 @@ def __repr__(self): return self.print_readable() def _get_action_string(self): - return "+".join([a.name.lower() for a in self.action]) + if self._action_string is None: + self._action_string = "+".join([a.name.lower() for a in self.action]) + return self._action_string def print_readable(self, indent=0): if indent > 9: @@ -112,16 +120,92 @@ def print_readable(self, indent=0): return result def to_dict(self) -> dict: - # Convert the object to a dictionary - action_string = self._get_action_string() - return { - "name": self.name, - "target": self.target, - "graph_id": self.graph_id, - "pass_name": self.pass_name, - "action": action_string, - "from_node": [node.to_dict() for node in self.from_node], - } + if self._dict is None: + # Convert the object to a dictionary + action_string = self._get_action_string() + self._dict = { + "name": self.name, + "target": self.target, + "graph_id": self.graph_id, + "pass_name": self.pass_name, + "action": action_string, + "from_node": [node.to_dict() for node in self.from_node], + } + + assert self._dict is not None + return self._dict + + def __eq__(self, other: object): + if not isinstance(other, NodeSource): + return False + return self.to_dict() == other.to_dict() + + def __hash__(self): + # Create a hash based on the dictionary representation + # We need to convert the dict to a hashable form + def _make_hashable(obj): + if isinstance(obj, dict): + return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) + elif isinstance(obj, list): + return tuple(_make_hashable(item) for item in obj) + else: + return obj + + return hash(_make_hashable(self.to_dict())) + + @classmethod + def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]: + """ + Recursively deserialize from_node metadata from dictionary data. + It is used to deserialize the from_node field from serialized metadata. + Please use constructor NodeSource(node, ...) to create a NodeSource object. + """ + if d is None: + return None + + assert isinstance(d, dict), f"Expected a dict, got {type(d)}" + + # Create a NodeSource object directly without going through the constructor + # to avoid issues with graph ID and node creation + node_source = NodeSource.__new__(NodeSource) + + # Reset the cached properties + node_source._action_string = None + node_source._dict = None + + # Set the basic attributes + node_source.pass_name = d.get("pass_name", "") + + # Parse action string back to NodeSourceAction enum list + action_str = d.get("action", "") + actions = [] + if action_str: + for action_name in action_str.split("+"): + if action_name.upper() == "CREATE": + actions.append(NodeSourceAction.CREATE) + elif action_name.upper() == "REPLACE": + actions.append(NodeSourceAction.REPLACE) + node_source.action = actions + + # Create the NodeInfo object directly + if "name" in d and "target" in d and "graph_id" in d: + node_info = NodeSource.NodeInfo( + d.get("name", ""), d.get("target", ""), d.get("graph_id", -1) + ) + node_source.node_info = node_info + else: + node_source.node_info = None + + # Recursively deserialize nested from_node + if d.get("from_node", None) is not None: + node_source.from_node = [ + result + for fn in d.get("from_node", []) + if (result := cls._from_dict(fn)) is not None + ] + else: + node_source.from_node = [] + return node_source @compatibility(is_backward_compatible=False) diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 49b2784e1df1..32c2d308d9d2 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -50,3 +50,6 @@ size # torch/headeronly/macros/Export.h C10_API + +# torch/headeronly/util/Exception.h +STD_TORCH_CHECK diff --git a/torch/headeronly/BUCK.oss b/torch/headeronly/BUCK.oss new file mode 100644 index 000000000000..2b8d77e597a6 --- /dev/null +++ b/torch/headeronly/BUCK.oss @@ -0,0 +1,26 @@ +load("//tools/build_defs:glob_defs.bzl", "subdir_glob") + +cxx_library( + name = "torch_headeronly", + header_namespace = "torch/headeronly", + exported_deps = [], + compiler_flags = [ + "-Werror", + "-Wno-global-constructors", + ], + exported_headers = subdir_glob( + [ + ("", "**/*.h"), + ], + ), + exported_linker_flags = [], + exported_preprocessor_flags = [ + '-DC10_USING_CUSTOM_GENERATED_MACROS', + '-DC10_USE_GLOG', + ], + link_whole = True, + platform_preprocessor_flags = [['windows', ['-D_WINDOWS', '-D_WIN32', '-DWIN32', '-DNOMINMAX', '-D_CRT_SECURE_NO_WARNINGS', '-D_USE_MATH_DEFINES']], ['windows.*64$', ['-D_WIN64']]], + preprocessor_flags = ['-DC10_BUILD_MAIN_LIB'], + reexport_all_header_dependencies = True, + visibility = ['PUBLIC'], +) diff --git a/torch/headeronly/BUILD.bazel b/torch/headeronly/BUILD.bazel index f4a27fac1f7f..030651b12043 100644 --- a/torch/headeronly/BUILD.bazel +++ b/torch/headeronly/BUILD.bazel @@ -1,9 +1,5 @@ load("@rules_cc//cc:defs.bzl", "cc_library") +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") -cc_library( - name = "torch_headeronly", - hdrs = glob([ - "**/*.h" - ]), - visibility = ["//visibility:public"], -) +define_targets(rules = rules) diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt new file mode 100644 index 000000000000..e42981d8804e --- /dev/null +++ b/torch/headeronly/CMakeLists.txt @@ -0,0 +1,35 @@ +cmake_minimum_required(VERSION 3.27 FATAL_ERROR) + +project(headeronly CXX) + +set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# Main build file for torch/headeronly, except there's no build cuz this lib is header-only! + +# ---[ Configure macro file. +set(C10_USE_GFLAGS ${USE_GFLAGS}) # used in cmake_macros.h.in +set(C10_USE_GLOG ${USE_GLOG}) # used in cmake_macros.h.in +set(C10_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) # used in cmake_macros.h.in +set(C10_USE_NUMA ${USE_NUMA}) # used in cmake_macros.h.in +set(C10_USE_MSVC_STATIC_RUNTIME ${CAFFE2_USE_MSVC_STATIC_RUNTIME}) # used in cmake_macros.h.in +set(C10_USE_ROCM_KERNEL_ASSERT ${USE_ROCM_KERNEL_ASSERT}) # used in cmake_macros.h.in +configure_file( + ${CMAKE_CURRENT_LIST_DIR}/macros/cmake_macros.h.in + ${CMAKE_BINARY_DIR}/torch/headeronly/macros/cmake_macros.h) + +file(GLOB HEADERONLY_HEADERS + *.h + macros/*.h + util/*.h +) + +add_library(headeronly INTERFACE ${HEADERONLY_HEADERS}) + +install(FILES ${CMAKE_BINARY_DIR}/torch/headeronly/macros/cmake_macros.h + DESTINATION include/torch/headeronly/macros) + +if(NOT BUILD_LIBTORCHLESS) + # ---[ Installation copied from c10/CMakeLists.txt + install(TARGETS headeronly EXPORT Caffe2Targets DESTINATION lib) +endif() diff --git a/torch/headeronly/build.bzl b/torch/headeronly/build.bzl new file mode 100644 index 000000000000..6ec9a843e884 --- /dev/null +++ b/torch/headeronly/build.bzl @@ -0,0 +1,11 @@ +def define_targets(rules): + rules.cc_library( + name = "torch_headeronly", + hdrs = rules.glob([ + "**/*.h" + ]), + visibility = ["//visibility:public"], + deps = [ + "//torch/headeronly/macros", + ], + ) diff --git a/torch/headeronly/macros/BUILD.bazel b/torch/headeronly/macros/BUILD.bazel new file mode 100644 index 000000000000..d1a0db360d23 --- /dev/null +++ b/torch/headeronly/macros/BUILD.bazel @@ -0,0 +1,4 @@ +load("//:tools/bazel.bzl", "rules") +load(":build.bzl", "define_targets") + +define_targets(rules = rules) diff --git a/torch/headeronly/macros/Export.h b/torch/headeronly/macros/Export.h index 183aeab56344..8dd25419efb4 100644 --- a/torch/headeronly/macros/Export.h +++ b/torch/headeronly/macros/Export.h @@ -1,5 +1,12 @@ #pragma once +#ifndef C10_MACROS_EXPORT_H_ +#define C10_MACROS_EXPORT_H_ + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + /* Header file to define the common scaffolding for exported symbols. * * Export is by itself a quite tricky situation to deal with, and if you are @@ -85,3 +92,62 @@ #else #define C10_API C10_IMPORT #endif + +// This one is being used by libtorch.so +#ifdef CAFFE2_BUILD_MAIN_LIB +#define TORCH_API C10_EXPORT +#else +#define TORCH_API C10_IMPORT +#endif + +// You may be wondering why we have TORCH_CUDA_CPP_API and TORCH_CUDA_CU_API +// belonging to the same library instead of just one TORCH_CUDA_API. Well, it +// can indeed just be one TORCH_CUDA_API (and used to be)! TORCH_CUDA_CPP_API +// and TORCH_CUDA_CU_API are artifacts of when we needed a split build to +// avoid relocation marker linking errors. The context is as follows: +// +// Once upon a time, there _was_ only TORCH_CUDA_API. All was happy until we +// tried to compile PyTorch for CUDA 11.1, which ran into relocation marker +// issues when linking big binaries. +// (https://github.com/pytorch/pytorch/issues/39968) We had two choices: +// (1) Stop supporting so many GPU architectures +// (2) Do something else +// We chose #2 and decided to split the behemoth that was torch_cuda into two +// smaller libraries, one with most of the core kernel functions (torch_cuda_cu) +// and the other that had..well..everything else (torch_cuda_cpp). The idea was +// this: instead of linking our static libraries (like the hefty +// libcudnn_static.a) with another huge library, torch_cuda, and run into pesky +// relocation marker issues, we could link our static libraries to a smaller +// part of torch_cuda (torch_cuda_cpp) and avoid the issues. + +// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the +// same api) +#ifdef TORCH_CUDA_BUILD_MAIN_LIB +#define TORCH_CUDA_CPP_API C10_EXPORT +#define TORCH_CUDA_CU_API C10_EXPORT +#else +#define TORCH_CUDA_CPP_API C10_IMPORT +#define TORCH_CUDA_CU_API C10_IMPORT +#endif + +#if defined(TORCH_HIP_BUILD_MAIN_LIB) +#define TORCH_HIP_CPP_API C10_EXPORT +#define TORCH_HIP_API C10_EXPORT +#else +#define TORCH_HIP_CPP_API C10_IMPORT +#define TORCH_HIP_API C10_IMPORT +#endif + +#if defined(TORCH_XPU_BUILD_MAIN_LIB) +#define TORCH_XPU_API C10_EXPORT +#else +#define TORCH_XPU_API C10_IMPORT +#endif + +// Enums only need to be exported on windows for non-CUDA files +#if defined(_WIN32) && defined(__CUDACC__) +#define C10_API_ENUM C10_API +#else +#define C10_API_ENUM +#endif +#endif // C10_MACROS_EXPORT_H_ diff --git a/torch/headeronly/macros/Macros.h b/torch/headeronly/macros/Macros.h new file mode 100644 index 000000000000..1e07ab0446e8 --- /dev/null +++ b/torch/headeronly/macros/Macros.h @@ -0,0 +1,564 @@ +#ifndef C10_MACROS_MACROS_H_ +#define C10_MACROS_MACROS_H_ +#include + +/* Main entry for torch/headeronly/macros (used to be c10/macros). + * + * In your code, include torch/headeronly/macros/Macros.h directly, instead of + * individual files in this folder. + */ + +// For build systems that do not directly depend on CMake and directly build +// from the source directory (such as Buck), one may not have a cmake_macros.h +// file at all. In this case, the build system is responsible for providing +// correct macro definitions corresponding to the cmake_macros.h.in file. +// +// In such scenarios, one should define the macro +// C10_USING_CUSTOM_GENERATED_MACROS +// to inform this header that it does not need to include the cmake_macros.h +// file. + +#ifndef C10_USING_CUSTOM_GENERATED_MACROS +#include +#endif // C10_USING_CUSTOM_GENERATED_MACROS + +#include + +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#define __ubsan_ignore_undefined__ __attribute__((no_sanitize("undefined"))) +#define __ubsan_ignore_signed_int_overflow__ \ + __attribute__((no_sanitize("signed-integer-overflow"))) +#define __ubsan_ignore_pointer_overflow__ \ + __attribute__((no_sanitize("pointer-overflow"))) +#define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#define __ubsan_ignore_undefined__ +#define __ubsan_ignore_signed_int_overflow__ +#define __ubsan_ignore_pointer_overflow__ +#define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ +#endif + +// Detect address sanitizer as some stuff doesn't work with it +#undef C10_ASAN_ENABLED + +// for clang +#if defined(__has_feature) +#if ((__has_feature(address_sanitizer))) +#define C10_ASAN_ENABLED 1 +#endif +#endif + +// for gcc +#if defined(__SANITIZE_ADDRESS__) +#if __SANITIZE_ADDRESS__ +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 1 +#endif +#endif +#endif + +#if !defined(C10_ASAN_ENABLED) +#define C10_ASAN_ENABLED 0 +#endif + +// Detect undefined-behavior sanitizer (UBSAN) +#undef C10_UBSAN_ENABLED + +// for clang or gcc >= 14 +// NB: gcc 14 adds support for Clang's __has_feature +// https://gcc.gnu.org/gcc-14/changes.html +// gcc < 14 doesn't have a macro for UBSAN +// (e.g. __SANITIZE_UNDEFINED__ does not exist in gcc) +// https://github.com/google/sanitizers/issues/765 +#if defined(__has_feature) +#if ((__has_feature(undefined_behavior_sanitizer))) +#define C10_UBSAN_ENABLED 1 +#endif +#endif + +#if !defined(C10_UBSAN_ENABLED) +#define C10_UBSAN_ENABLED 0 +#endif + +// Disable the copy and assignment operator for a class. Note that this will +// disable the usage of the class in std containers. +#define C10_DISABLE_COPY_AND_ASSIGN(classname) \ + classname(const classname&) = delete; \ + classname& operator=(const classname&) = delete + +#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 +#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) + +#define C10_MACRO_EXPAND(args) args + +#define C10_STRINGIZE_IMPL(x) #x +#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) + +/** + * C10_ANONYMOUS_VARIABLE(str) introduces a new identifier which starts with + * str and ends with a unique number. + */ +#ifdef __COUNTER__ +#define C10_UID __COUNTER__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) +#else +#define C10_UID __LINE__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) +#endif + +#ifdef __has_cpp_attribute +#define C10_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +#define C10_HAS_CPP_ATTRIBUTE(x) (0) +#endif + +#ifndef FBCODE_CAFFE2 +/// DEPRECATED: Warn if a type or return value is discarded. +#define C10_NODISCARD [[nodiscard]] + +/// DEPRECATED: Suppress an unused variable. +#define C10_UNUSED [[maybe_unused]] +#endif + +#if !defined(__has_attribute) +#define __has_attribute(x) 0 +#endif + +// Direct port of LLVM_ATTRIBUTE_USED. +#if __has_attribute(used) +#define C10_USED __attribute__((__used__)) +#else +#define C10_USED +#endif + +#define C10_RESTRICT __restrict + +// Simply define the namespace, in case a dependent library want to refer to +// the c10 namespace but not any nontrivial files. +namespace c10 {} +namespace c10::cuda {} +namespace c10::hip {} +namespace c10::xpu {} + +// Since C10 is the core library for caffe2 (and aten), we will simply reroute +// all abstractions defined in c10 to be available in caffe2 as well. +// This is only for backwards compatibility. Please use the symbols from the +// c10 namespace where possible. +namespace caffe2 { +using namespace c10; +} +namespace at { +using namespace c10; +} +namespace at::cuda { +using namespace c10::cuda; +} // namespace at::cuda + +// WARNING!!! THIS IS A GIANT HACK!!! +// This line means you cannot simultaneously include c10/hip +// and c10/cuda and then use them from the at::cuda namespace. +// This is true in practice, because HIPIFY works inplace on +// files in ATen/cuda, so it assumes that c10::hip is available +// from at::cuda. This namespace makes that happen. When +// HIPIFY is no longer out-of-place, we can switch the cuda +// here to hip and everyone is happy. +namespace at::cuda { +using namespace c10::hip; +} // namespace at::cuda + +namespace at::xpu { +using namespace c10::xpu; +} // namespace at::xpu + +// C10_LIKELY/C10_UNLIKELY +// +// These macros provide parentheses, so you can use these macros as: +// +// if C10_LIKELY(some_expr) { +// ... +// } +// +// NB: static_cast to boolean is mandatory in C++, because __builtin_expect +// takes a long argument, which means you may trigger the wrong conversion +// without it. +// +#if defined(__GNUC__) || defined(__ICL) || defined(__clang__) +#define C10_LIKELY(expr) (__builtin_expect(static_cast(expr), 1)) +#define C10_UNLIKELY(expr) (__builtin_expect(static_cast(expr), 0)) +#else +#define C10_LIKELY(expr) (expr) +#define C10_UNLIKELY(expr) (expr) +#endif + +/// C10_NOINLINE - Functions whose declaration is annotated with this will not +/// be inlined. +#ifdef __GNUC__ +#define C10_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define C10_NOINLINE __declspec(noinline) +#else +#define C10_NOINLINE +#endif + +#if defined(_MSC_VER) +#define C10_ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define C10_ALWAYS_INLINE inline +#endif + +// Unlike C10_ALWAYS_INLINE, C10_ALWAYS_INLINE_ATTRIBUTE can be used +// on a lambda. +#if defined(_MSC_VER) +// MSVC 14.39 is reasonably recent and doesn't like +// [[msvc::forceinline]] on a lambda, so don't try to use it. +#define C10_ALWAYS_INLINE_ATTRIBUTE +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define C10_ALWAYS_INLINE_ATTRIBUTE __attribute__((__always_inline__)) +#else +#define C10_ALWAYS_INLINE_ATTRIBUTE +#endif + +#if defined(_MSC_VER) +#define C10_ATTR_VISIBILITY_HIDDEN +#elif defined(__GNUC__) +#define C10_ATTR_VISIBILITY_HIDDEN __attribute__((__visibility__("hidden"))) +#else +#define C10_ATTR_VISIBILITY_HIDDEN +#endif + +#define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN + +#include + +#ifdef __HIPCC__ +// Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. +// We do this #include here so that C10_HOST_DEVICE and friends will Just Work. +// See https://github.com/ROCm/hip/issues/441 +#include +#endif + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define C10_HOST_DEVICE __host__ __device__ +#define C10_DEVICE __device__ +#define C10_HOST __host__ +// constants from +// (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications) +// The maximum number of threads per multiprocessor is 1024 for Turing +// architecture (7.5), 1536 for Geforce Ampere (8.6)/Jetson Orin (8.7), and +// 2048 for all other architectures. You'll get warnings if you exceed these +// constants. Hence, the following macros adjust the input values from the user +// to resolve potential warnings. +#if __CUDA_ARCH__ == 750 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1024; +#elif __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 890 +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 1536; +#else +constexpr uint32_t CUDA_MAX_THREADS_PER_SM = 2048; +#endif +// CUDA_MAX_THREADS_PER_BLOCK is same for all architectures currently +constexpr uint32_t CUDA_MAX_THREADS_PER_BLOCK = 1024; +// CUDA_THREADS_PER_BLOCK_FALLBACK is the "canonical fallback" choice of block +// size. 256 is a good number for this fallback and should give good occupancy +// and versatility across all architectures. +constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; +// NOTE: if you are thinking of constexpr-ify the inputs to launch bounds, it +// turns out that although __launch_bounds__ can take constexpr, it +// can't take a constexpr that has anything to do with templates. +// Currently we use launch_bounds that depend on template arguments in +// Loops.cuh, Reduce.cuh and LossCTC.cuh. Hence, C10_MAX_THREADS_PER_BLOCK +// and C10_MIN_BLOCKS_PER_SM are kept as macros. +// Suppose you were planning to write __launch_bounds__(a, b), based on your +// performance tuning on a modern GPU. Instead, you should write +// __launch_bounds__(C10_MAX_THREADS_PER_BLOCK(a), C10_MIN_BLOCKS_PER_SM(a, b)), +// which will also properly respect limits on old architectures. +#define C10_MAX_THREADS_PER_BLOCK(val) \ + (((val) <= CUDA_MAX_THREADS_PER_BLOCK) ? (val) \ + : CUDA_THREADS_PER_BLOCK_FALLBACK) +#define C10_MIN_BLOCKS_PER_SM(threads_per_block, blocks_per_sm) \ + ((((threads_per_block) * (blocks_per_sm) <= CUDA_MAX_THREADS_PER_SM) \ + ? (blocks_per_sm) \ + : ((CUDA_MAX_THREADS_PER_SM + (threads_per_block) - 1) / \ + (threads_per_block)))) +// C10_LAUNCH_BOUNDS is analogous to __launch_bounds__ +#define C10_LAUNCH_BOUNDS_0 \ + __launch_bounds__( \ + 256, 4) // default launch bounds that should give good occupancy and + // versatility across all architectures. +#define C10_LAUNCH_BOUNDS_1(max_threads_per_block) \ + __launch_bounds__((C10_MAX_THREADS_PER_BLOCK((max_threads_per_block)))) +#define C10_LAUNCH_BOUNDS_2(max_threads_per_block, min_blocks_per_sm) \ + __launch_bounds__( \ + (C10_MAX_THREADS_PER_BLOCK((max_threads_per_block))), \ + (C10_MIN_BLOCKS_PER_SM((max_threads_per_block), (min_blocks_per_sm)))) +#else +#define C10_HOST_DEVICE +#define C10_HOST +#define C10_DEVICE +#endif + +#if defined(USE_ROCM) +#define C10_HIP_HOST_DEVICE __host__ __device__ +#else +#define C10_HIP_HOST_DEVICE +#endif + +#if defined(USE_ROCM) +// C10_WARP_SIZE is only allowed for device code. +// Host code _must_ use at::cuda::warp_size() +// HIP header used to define warpSize as a constexpr that was either 32 or 64 +// depending on the target device, and then always set it to 64 for host code. +// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we +// set it to something unreasonable to trigger obvious host code errors. + +namespace at::cuda { +TORCH_CUDA_CPP_API int warp_size(); +} +#ifdef __HIPCC__ +static inline int __host__ C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} + +static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() { +#if defined(__GFX9__) + return 64; +#else // __GFX9__ + return 32; +#endif // __GFX9__ +} +#else // __HIPCC__ +static inline int C10_WARP_SIZE_INTERNAL() { + return at::cuda::warp_size(); +} +#endif // __HIPCC__ + +#define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL()) +#define C10_WARP_SIZE_STATIC 64 + +#else // defined(USE_ROCM) +#define C10_WARP_SIZE 32 +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1900 +#define __func__ __FUNCTION__ +#endif + +// CUDA_KERNEL_ASSERT checks the assertion +// even when NDEBUG is defined. This is useful for important assertions in CUDA +// code that would otherwise be suppressed when building Release. +#if defined(__ANDROID__) || defined(__APPLE__) || defined(__FreeBSD__) +// Those platforms do not support assert() +#define CUDA_KERNEL_ASSERT(cond) +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) +#define SYCL_KERNEL_ASSERT(cond) +#elif defined(_MSC_VER) +#if defined(NDEBUG) +extern "C" { +C10_IMPORT +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void _wassert( + const wchar_t* wexpr, + const wchar_t* wfile, + unsigned line); +#else +#if defined(__CUDA_ARCH__) +__host__ __device__ +#endif // __CUDA_ARCH__ + void + _wassert(wchar_t const* _Message, wchar_t const* _File, unsigned _Line); +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +// TODO: This doesn't assert the message because I (chilli) couldn't figure out +// a nice way to convert a char* to a wchar_t* +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + (void)(_wassert( \ + _CRT_WIDE(#cond), \ + _CRT_WIDE(__FILE__), \ + static_cast(__LINE__)), \ + 0); \ + } +#else // __APPLE__, _MSC_VER +#if defined(NDEBUG) +extern "C" { +#if defined(__SYCL_DEVICE_ONLY__) +extern SYCL_EXTERNAL void __assert_fail( + const char* expr, + const char* file, + unsigned int line, + const char* func); +#else // __SYCL_DEVICE_ONLY__ +#if (defined(__CUDA_ARCH__) && !(defined(__clang__) && defined(__CUDA__))) +// CUDA supports __assert_fail function which are common for both device +// and host side code. +__host__ __device__ +#endif + + // This forward declaration matching the declaration of __assert_fail + // exactly how it is in glibc in case parts of the program are compiled with + // different NDEBUG settings. Otherwise we might get 'ambiguous declaration' + // error. Note: On ROCm - this declaration serves for host side compilation. + void + __assert_fail( + const char* assertion, + const char* file, + unsigned int line, + const char* function) noexcept __attribute__((__noreturn__)); + +#endif // __SYCL_DEVICE_ONLY__ +} +#endif // NDEBUG +// ROCm disables kernel assert by default for performance considerations. +// Though ROCm supports __assert_fail, it uses kernel printf which has +// a non-negligible performance impact even if the assert condition is +// never triggered. We choose to use abort() instead which will still +// terminate the application but without a more useful error message. +#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) +#define CUDA_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if C10_UNLIKELY (!(cond)) { \ + abort(); \ + } +#else +#define CUDA_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#define CUDA_KERNEL_ASSERT_MSG(cond, msg) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + msg, __FILE__, static_cast(__LINE__), __func__); \ + } +#define SYCL_KERNEL_ASSERT(cond) \ + if (C10_UNLIKELY(!(cond))) { \ + __assert_fail( \ + #cond, __FILE__, static_cast(__LINE__), __func__); \ + } +#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM +#endif // __APPLE__ + +#ifdef __APPLE__ +#include +#endif + +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#define C10_MOBILE 1 +#elif ( \ + defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#define C10_MOBILE 1 +#endif // ANDROID / IOS + +#if defined(C10_MOBILE) && C10_MOBILE +#define C10_ALWAYS_INLINE_UNLESS_MOBILE inline +#else +#define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE +#endif + +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) +#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr +#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr + +#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ + static constexpr const char field[] = val; +#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + +#ifndef HAS_DEMANGLE +#if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) +#define HAS_DEMANGLE 0 +#elif defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE) +#define HAS_DEMANGLE 0 +#else +#define HAS_DEMANGLE 1 +#endif +#endif // HAS_DEMANGLE + +#define _C10_PRAGMA__(string) _Pragma(#string) +#define _C10_PRAGMA_(string) _C10_PRAGMA__(string) + +#ifdef __clang__ +#define C10_CLANG_DIAGNOSTIC_PUSH() _Pragma("clang diagnostic push") +#define C10_CLANG_DIAGNOSTIC_POP() _Pragma("clang diagnostic pop") +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) \ + _C10_PRAGMA_(clang diagnostic ignored flag) +#define C10_CLANG_HAS_WARNING(flag) __has_warning(flag) +#else +#define C10_CLANG_DIAGNOSTIC_PUSH() +#define C10_CLANG_DIAGNOSTIC_POP() +#define C10_CLANG_DIAGNOSTIC_IGNORE(flag) +#define C10_CLANG_HAS_WARNING(flag) 0 +#endif + +#ifdef __clang__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(clang diagnostic push) \ + _C10_PRAGMA_(clang diagnostic ignored "-Wunknown-warning-option") \ + _C10_PRAGMA_(clang diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(clang diagnostic pop) + +#elif __GNUC__ + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) \ + _C10_PRAGMA_(GCC diagnostic push) \ + _C10_PRAGMA_(GCC diagnostic ignored "-Wpragmas") \ + _C10_PRAGMA_(GCC diagnostic ignored warning) + +#define C10_DIAGNOSTIC_POP() _C10_PRAGMA_(GCC diagnostic pop) + +#else + +#define C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(warning) +#define C10_DIAGNOSTIC_POP() + +#endif + +// This macro is used to find older C++ compilers +// that don't support move optimization for return values. + +#if (defined(__GNUC__) && __GNUC__ < 13) || \ + (defined(__clang_major__) && __clang_major__ < 13) +#define C10_RETURN_MOVE_IF_OLD_COMPILER 1 +#else +#define C10_RETURN_MOVE_IF_OLD_COMPILER 0 +#endif + +#endif // C10_MACROS_MACROS_H_ diff --git a/torch/headeronly/macros/build.bzl b/torch/headeronly/macros/build.bzl new file mode 100644 index 000000000000..00d31a40163c --- /dev/null +++ b/torch/headeronly/macros/build.bzl @@ -0,0 +1,29 @@ +def define_targets(rules): + rules.cc_library( + name = "macros", + srcs = [":cmake_macros_h"], + hdrs = [ + # Following the example from c10 + "Export.h", + "Macros.h", + ], + linkstatic = True, + local_defines = ["C10_BUILD_MAIN_LIB"], + visibility = ["//visibility:public"], + ) + + rules.cmake_configure_file( + name = "cmake_macros_h", + src = "cmake_macros.h.in", + out = "cmake_macros.h", + definitions = [ + "C10_BUILD_SHARED_LIBS", + "C10_USE_MSVC_STATIC_RUNTIME", + ] + rules.select({ + "//c10:using_gflags": ["C10_USE_GFLAGS"], + "//conditions:default": [], + }) + rules.select({ + "//c10:using_glog": ["C10_USE_GLOG"], + "//conditions:default": [], + }), + ) diff --git a/c10/macros/cmake_configure_file.bzl b/torch/headeronly/macros/cmake_configure_file.bzl similarity index 100% rename from c10/macros/cmake_configure_file.bzl rename to torch/headeronly/macros/cmake_configure_file.bzl diff --git a/c10/macros/cmake_macros.h.in b/torch/headeronly/macros/cmake_macros.h.in similarity index 80% rename from c10/macros/cmake_macros.h.in rename to torch/headeronly/macros/cmake_macros.h.in index 76c185b55236..e624221202df 100644 --- a/c10/macros/cmake_macros.h.in +++ b/torch/headeronly/macros/cmake_macros.h.in @@ -2,7 +2,7 @@ #define C10_MACROS_CMAKE_MACROS_H_ // Automatically generated header file for the C10 library. -// Do not include this file directly. Instead, include c10/macros/Macros.h. +// Do not include this file directly. Instead, include torch/headeronly/macros/Macros.h. #cmakedefine C10_BUILD_SHARED_LIBS #cmakedefine C10_USE_GLOG diff --git a/torch/headeronly/ovrsource_defs.bzl b/torch/headeronly/ovrsource_defs.bzl index 55e1947b5e76..6d1051fed2e4 100644 --- a/torch/headeronly/ovrsource_defs.bzl +++ b/torch/headeronly/ovrsource_defs.bzl @@ -1,3 +1,4 @@ +load("//arvr/tools/build_defs:genrule_utils.bzl", "gen_cmake_header") load("//arvr/tools/build_defs:oxx.bzl", "oxx_static_library") cpu_supported_platforms = [ @@ -18,29 +19,78 @@ def define_torch_headeronly_ovrsource(name, is_mobile): oxx_static_library( name = name, - srcs = [] + srcs = [], compatible_with = cpu_supported_platforms, compiler_flags = select({ "DEFAULT": [], }), - include_directories = [".."], - preprocessor_flags = [], + preprocessor_flags = ["-DC10_BUILD_MAIN_LIB=1",], fbobjc_compiler_flags = [], - public_include_directories = [".."], + public_include_directories = ["../.."], public_preprocessor_flags = pp_flags, public_raw_headers = native.glob([ "macros/*.h", + "util/*.h", ]), reexport_all_header_dependencies = False, visibility = [ - "//xplat/caffe2/torch/headeronly:torch_headeronly", + "//xplat/caffe2/torch/headeronly:torch_headeronly_ovrsource", + ], + exported_deps = [ + ":ovrsource_torch_headeronly_cmake_macros.h", + ], + ) + +def define_ovrsource_targets(): + common_c10_cmake_defines = [ + ("#cmakedefine C10_BUILD_SHARED_LIBS", ""), + ("#cmakedefine C10_USE_NUMA", ""), + ("#cmakedefine C10_USE_MSVC_STATIC_RUNTIME", ""), + ("#cmakedefine C10_USE_ROCM_KERNEL_ASSERT", ""), + ] + + mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", ""), + ("#cmakedefine C10_USE_GFLAGS", ""), + ] + + non_mobile_c10_cmake_defines = [ + ("#cmakedefine C10_USE_GLOG", "#define C10_USE_GLOG 1"), + ("#cmakedefine C10_USE_GFLAGS", "#define C10_USE_GFLAGS 1"), + ] + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + mobile_c10_cmake_defines, + header = "torch/headeronly/macros/cmake_macros.h", + prefix = "ovrsource_torch_headeronly_mobile_", + ) + + gen_cmake_header( + src = "macros/cmake_macros.h.in", + defines = common_c10_cmake_defines + non_mobile_c10_cmake_defines, + header = "torch/headeronly/macros/cmake_macros.h", + prefix = "ovrsource_torch_headeronly_non_mobile_", + ) + + oxx_static_library( + name = "ovrsource_torch_headeronly_cmake_macros.h", + compatible_with = [ + "ovr_config//os:android", + "ovr_config//os:iphoneos", + "ovr_config//os:linux", + "ovr_config//os:macos", + "ovr_config//os:windows", ], deps = select({ - "DEFAULT": [], + "ovr_config//os:android": [":ovrsource_torch_headeronly_mobile_cmake_macros.h"], + "ovr_config//os:iphoneos": [":ovrsource_torch_headeronly_mobile_cmake_macros.h"], + "ovr_config//os:linux": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], + "ovr_config//os:macos": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], + "ovr_config//os:windows": [":ovrsource_torch_headeronly_non_mobile_cmake_macros.h"], }), ) -def define_ovrsource_targets(): oxx_static_library( name = "torch_headeronly_ovrsource", compatible_with = cpu_supported_platforms, diff --git a/torch/headeronly/util/Exception.h b/torch/headeronly/util/Exception.h new file mode 100644 index 000000000000..c5d05e0fa955 --- /dev/null +++ b/torch/headeronly/util/Exception.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { +// On nvcc, C10_UNLIKELY thwarts missing return statement analysis. In cases +// where the unlikely expression may be a constant, use this macro to ensure +// return statement analysis keeps working (at the cost of not getting the +// likely/unlikely annotation on nvcc). +// https://github.com/pytorch/pytorch/issues/21418 +// +// Currently, this is only used in the error reporting macros below. If you +// want to use it more generally, move me to Macros.h +// +// TODO: Brian Vaughan observed that we might be able to get this to work on +// nvcc by writing some sort of C++ overload that distinguishes constexpr inputs +// from non-constexpr. Since there isn't any evidence that losing C10_UNLIKELY +// in nvcc is causing us perf problems, this is not yet implemented, but this +// might be an interesting piece of C++ code for an intrepid bootcamper to +// write. +#if defined(__CUDACC__) +#define C10_UNLIKELY_OR_CONST(e) e +#else +#define C10_UNLIKELY_OR_CONST(e) C10_UNLIKELY(e) +#endif + +} // namespace c10 + +// STD_TORCH_CHECK throws std::runtime_error instead of c10::Error which is +// useful when certain headers are used in a libtorch-independent way, +// e.g. when Vectorized is used in AOTInductor generated code, or +// for custom ops to have an ABI stable dependency on libtorch. +#ifdef STRIP_ERROR_MESSAGES +#define STD_TORCH_CHECK_MSG(cond, type, ...) \ + (#cond #type " CHECK FAILED at " C10_STRINGIZE(__FILE__)) +#else // so STRIP_ERROR_MESSAGES is not defined +namespace torch::headeronly::detail { +template +std::string stdTorchCheckMsgImpl(const char* /*msg*/, const Args&... args) { + // This is similar to the one in c10/util/Exception.h, but does + // not depend on the more complex c10::str() function. ostringstream + // supports fewer data types than c10::str(), but should be sufficient + // in the headeronly world. + std::ostringstream oss; + ((oss << args), ...); + return oss.str(); +} + +inline const char* stdTorchCheckMsgImpl(const char* msg) { + return msg; +} +// If there is just 1 user-provided C-string argument, use it. +inline const char* stdTorchCheckMsgImpl(const char* /*msg*/, const char* args) { + return args; +} +} // namespace torch::headeronly::detail + +#define STD_TORCH_CHECK_MSG(cond, type, ...) \ + (torch::headeronly::detail::stdTorchCheckMsgImpl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to PyTorch.)", \ + ##__VA_ARGS__)) +#endif // STRIP_ERROR_MESSAGES + +#define STD_TORCH_CHECK(cond, ...) \ + if (C10_UNLIKELY_OR_CONST(!(cond))) { \ + throw std::runtime_error(STD_TORCH_CHECK_MSG( \ + cond, \ + "", \ + __func__, \ + ", ", \ + __FILE__, \ + ":", \ + __LINE__, \ + ", ", \ + ##__VA_ARGS__)); \ + } diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 79442f57d306..ccd967d69f4e 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -704,10 +704,7 @@ def _reconstruct(self, cpp_module): @property def graph(self): - r"""Return a string representation of the internal graph for the ``forward`` method. - - See :ref:`interpreting-graphs` for details. - """ + r"""Return a string representation of the internal graph for the ``forward`` method.""" return self._c._get_method("forward").graph @property @@ -716,7 +713,6 @@ def inlined_graph(self): Return a string representation of the internal graph for the ``forward`` method. This graph will be preprocessed to inline all function and method calls. - See :ref:`interpreting-graphs` for details. """ return self.forward.inlined_graph # type: ignore[attr-defined] @@ -725,7 +721,6 @@ def code(self): r""" Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method. - See :ref:`inspecting-code` for details. """ return self.forward.code # type: ignore[attr-defined] @@ -740,7 +735,6 @@ def code_with_constants(self): [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant's values. - See :ref:`inspecting-code` for details. """ r = self.forward.code_with_constants # type: ignore[attr-defined] return (r[0], ConstMap(r[1])) @@ -1246,7 +1240,7 @@ def script( subsequently passed by reference between Python and TorchScript with zero copy overhead. ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists - and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions. + and as a decorator ``@torch.jit.script`` for torchscript-classes and functions. Args: obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type, diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 791a11a9b3aa..98229edff6ee 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -243,8 +243,8 @@ def _get_global_builtins(): "getattr": "Attribute name must be a literal string", "hasattr": "Attribute name must be a literal string", "isinstance": "Result is static", - "zip": "Arguments must be iterable. See :ref:`Iterables ` for details.", - "enumerate": "Arguments must be iterable. See :ref:`Iterables ` for details.", + "zip": "Arguments must be iterable.", + "enumerate": "Arguments must be iterable.", "range": "Can only be used as an iterator in a for loop", } @@ -295,7 +295,7 @@ def _get_global_builtins(): {schemaless_ops_str} -The following functions will use the corresponding magic method on :any:`TorchScript classes` +The following functions will use the corresponding magic method on TorchScript classes .. csv-table:: :header: "Function", "Magic Method" diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 719df7eac464..8135f149a1bf 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -351,7 +351,10 @@ def _apply_fn_on_data(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._to_copy]) def _to_copy(func, *args, **kwargs): new_data = func(_get_data(args[0]), *args[1:], **kwargs) - return MaskedTensor(new_data, _maybe_get_mask(args[0])) + cloned_kwargs = kwargs.copy() + cloned_kwargs["dtype"] = torch.bool + new_mask = func(_maybe_get_mask(args[0]), *args[1:], **cloned_kwargs) + return MaskedTensor(new_data, new_mask) @register_dispatch_func([torch.ops.aten._softmax]) diff --git a/torch/nativert/detail/ITree.cpp b/torch/nativert/detail/ITree.cpp index 123ee4498d06..cd24ca78320f 100644 --- a/torch/nativert/detail/ITree.cpp +++ b/torch/nativert/detail/ITree.cpp @@ -46,7 +46,7 @@ class PytreeNodeRegistry { const ITreeSpec& spec, std::vector& ivalues) { const auto& tuple = nested.toTupleRef().elements(); - TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + TORCH_CHECK(tuple.size() == spec.children().size()); for (size_t i = 0; i < tuple.size(); i++) { itreeFlatten(tuple[i], spec.children(i), ivalues); } @@ -60,7 +60,7 @@ class PytreeNodeRegistry { const c10::IValue& nested, const ITreeSpec& spec) { const auto& tuple = nested.toTupleRef().elements(); - TORCH_CHECK_EQ(tuple.size(), spec.children().size()); + TORCH_CHECK(tuple.size() == spec.children().size()); for (size_t i = 0; i < tuple.size(); i++) { ivalueApply(fn, tuple[i], spec.children(i)); } @@ -119,7 +119,7 @@ class PytreeNodeRegistry { const auto& contextKeys = spec.contextKeys(); // allow the dict size less than the spec, missing key will be // filled with empty tensor - TORCH_CHECK_LE(dict.size(), contextKeys.size()); + TORCH_CHECK(dict.size() <= contextKeys.size()); size_t i = 0; for (const auto& key : contextKeys) { auto it = dict.find(key); @@ -143,7 +143,7 @@ class PytreeNodeRegistry { c10::Dict dict( c10::AnyType::get(), c10::AnyType::get()); TORCH_CHECK(obj.is_array()); - TORCH_CHECK_EQ(obj.size(), flats.size()); + TORCH_CHECK(obj.size() == flats.size()); dict.reserve(flats.size()); for (size_t i = 0; i < flats.size(); i++) { dict.insert(dynamicToIValue(obj[i]), std::move(flats[i])); @@ -200,7 +200,7 @@ ITreeSpec makeITreeSpec( TORCH_CHECK(obj.is_object()); TORCH_CHECK(obj.find("type") != obj.end()); if (obj["type"].is_null()) { - TORCH_CHECK_EQ(obj["children_spec"].size(), 0); + TORCH_CHECK(obj["children_spec"].empty()); TORCH_CHECK(obj["context"].is_null()); const Value* value = values[start]; @@ -244,11 +244,11 @@ ITreeSpec itreeSpecLoads( const std::vector& values) { const auto obj = nlohmann::json::parse(json); TORCH_CHECK(obj.is_array()); - TORCH_CHECK_EQ(obj.size(), 2); - TORCH_CHECK_EQ(obj[0].get(), kDefaultTreeSpecSerializationProtocol); + TORCH_CHECK(obj.size() == 2); + TORCH_CHECK(obj[0].get() == kDefaultTreeSpecSerializationProtocol); auto result = makeITreeSpec(obj[1], values, 0); - TORCH_CHECK_EQ(result.numIValues(), values.size()); + TORCH_CHECK(result.numIValues() == values.size()); return result; } @@ -256,7 +256,7 @@ c10::IValue itreeUnflatten( std::vector ivalues, const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeUnflatten"); - TORCH_CHECK_EQ(ivalues.size(), spec.numIValues()); + TORCH_CHECK(ivalues.size() == spec.numIValues()); if (spec.isIValue()) { return std::move(ivalues[0]); } @@ -299,20 +299,20 @@ std::vector itreeFlattenFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::itreeFlattenFromArgs"); TORCH_CHECK(!spec.isIValue()); - TORCH_CHECK_EQ(spec.children().size(), 2); + TORCH_CHECK(spec.children().size() == 2); std::vector ivalues; ivalues.reserve(spec.numIValues()); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); - TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + TORCH_CHECK(specArgs.children().size() == args.size()); for (size_t i = 0; i < args.size(); i++) { itreeFlatten(args[i], specArgs.children(i), ivalues); } const auto& specKwargs = spec.children(1); TORCH_CHECK(!specKwargs.isIValue()); - TORCH_CHECK_EQ(specKwargs.context().size(), kwargs.size()); + TORCH_CHECK(specKwargs.context().size() == kwargs.size()); for (size_t i = 0; i < specKwargs.context().size(); i++) { itreeFlatten( kwargs.at(specKwargs.context()[i].get_ref()), @@ -329,11 +329,11 @@ void ivalueApplyFromArgs( const ITreeSpec& spec) { RECORD_USER_SCOPE("nativert::ivalueApplyFromArgs"); TORCH_CHECK(!spec.isIValue()); - TORCH_CHECK_EQ(spec.children().size(), 2); + TORCH_CHECK(spec.children().size() == 2); const auto& specArgs = spec.children(0); TORCH_CHECK(!specArgs.isIValue()); - TORCH_CHECK_EQ(specArgs.children().size(), args.size()); + TORCH_CHECK(specArgs.children().size() == args.size()); for (size_t i = 0; i < args.size(); i++) { ivalueApply(fn, args[i], specArgs.children(i)); } @@ -342,7 +342,7 @@ void ivalueApplyFromArgs( TORCH_CHECK(!specKwargs.isIValue()); const auto& ctx = specKwargs.context(); - TORCH_CHECK_EQ(ctx.size(), kwargs.size()); + TORCH_CHECK(ctx.size() == kwargs.size()); for (size_t i = 0; i < ctx.size(); i++) { ivalueApply( diff --git a/torch/nativert/executor/ConstantFolder.cpp b/torch/nativert/executor/ConstantFolder.cpp index 7db1fd736243..13d253394805 100644 --- a/torch/nativert/executor/ConstantFolder.cpp +++ b/torch/nativert/executor/ConstantFolder.cpp @@ -24,8 +24,9 @@ namespace torch::nativert { void ConstantFolder::unlinkConstants( std::vector>& kernels) { - TORCH_CHECK_EQ(kernels.size(), graph_.nodes().size()) - << "graph node count and kernel count should be equal"; + TORCH_CHECK( + kernels.size() == graph_.nodes().size(), + "graph node count and kernel count should be equal"); unlinked_ = true; @@ -135,8 +136,9 @@ void ConstantFolder::unlinkConstants( */ void ConstantFolder::evaluate(Weights& weights) { - CHECK(unlinked_) - << "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"; + TORCH_CHECK( + unlinked_, + "cannot evaluate weights for a graph whose constants have not been unlinked via ConstFolder::unlinkConstants"); weights.validateAllWeightsLoaded(); diff --git a/torch/nativert/executor/DelegateExecutor.h b/torch/nativert/executor/DelegateExecutor.h index b8c3d506c431..7d88f9898776 100644 --- a/torch/nativert/executor/DelegateExecutor.h +++ b/torch/nativert/executor/DelegateExecutor.h @@ -46,6 +46,8 @@ class DelegateExecutor { // This call activate the processed weights. virtual void commitWeights() = 0; + virtual void initWeights(std::shared_ptr weights) = 0; + virtual std::vector run(std::vector& inputs) = 0; }; diff --git a/torch/nativert/executor/ExecutionFrame.h b/torch/nativert/executor/ExecutionFrame.h index ae8821a6e58b..4cf02054bc5e 100644 --- a/torch/nativert/executor/ExecutionFrame.h +++ b/torch/nativert/executor/ExecutionFrame.h @@ -46,13 +46,14 @@ class ExecutionFrame { } template - auto withMemoryPlanner(CB&& cb) { + auto withManagedMemory(CB&& cb) { if (!layoutManager_) { - return std::forward(cb)(); + return std::forward(cb)(nullptr); } LayoutManagerGuard guard(*layoutManager_); - return std::forward(cb)(); + return std::forward(cb)( + const_cast(layoutManager_.get())); } std::vector tryMoveUserOutputs(); @@ -123,8 +124,10 @@ class ExecutionFrame { } c10::intrusive_ptr getWork(int64_t workId) const { - CHECK(work_.find(workId) != work_.end()) - << "Couldn't find work with Id: " << workId; + TORCH_CHECK( + work_.find(workId) != work_.end(), + "Couldn't find work with Id: ", + workId); return work_.at(workId); } @@ -150,7 +153,7 @@ class ExecutionFrame { private: bool isOutputMovable(size_t idx) const { - TORCH_CHECK_LT(idx, moveable_output_mask_.size()); + TORCH_CHECK(idx < moveable_output_mask_.size()); return moveable_output_mask_[idx]; } diff --git a/torch/nativert/executor/Executor.cpp b/torch/nativert/executor/Executor.cpp index 285b6dea00dd..3a3f3d335137 100644 --- a/torch/nativert/executor/Executor.cpp +++ b/torch/nativert/executor/Executor.cpp @@ -19,40 +19,34 @@ namespace torch::nativert { Executor::Executor( torch::nativert::ExecutorConfig executorConfig, std::shared_ptr graph, - std::shared_ptr weights, - const Placement& placement, - std::shared_ptr pytorchStreamReader, - const MakeProxyExecutorFn& makeProxyExecutorFunc) + const std::shared_ptr& weights, + Placement placement, + const std::shared_ptr& + pytorchStreamReader) : executorConfig_(std::move(executorConfig)), graph_(std::move(graph)), - placement_(placement), + placement_(std::move(placement)), constantFolder_( executorConfig_.runConstFolding ? std::optional(*graph_) : std::nullopt), - makeProxyExecutorFunc_(makeProxyExecutorFunc), executionFrames_(executorConfig_.maxNumConcurrentThreads), clearedExecutionFrames_(executorConfig_.maxNumConcurrentThreads), numExecutionFrames_(0), lastClearedTimestamp_(getCurrentTimestampSeconds()) { if (weights) { - initialize(std::move(weights), std::move(pytorchStreamReader)); + initialize(weights, pytorchStreamReader); } } void Executor::initialize( - std::shared_ptr weights, - std::shared_ptr + const std::shared_ptr& weights, + const std::shared_ptr& pytorchStreamReader) { auto start = std::chrono::high_resolution_clock::now(); auto executionKernels = KernelFactory().initializeNodeKernels( - *graph_, - weights, - executorConfig_, - placement_, - std::move(pytorchStreamReader), - makeProxyExecutorFunc_); + *graph_, weights, executorConfig_, placement_, pytorchStreamReader); if (constantFolder_.has_value()) { constantFolder_->unlinkConstants(executionKernels.nodeKernels); @@ -71,9 +65,7 @@ void Executor::initialize( delegateExecutors_ = std::move(executionKernels.delegateExecutors); constFoldingExecutions_ = std::move(executionKernels.constFoldingExecutions); - // initialize weights_ - processWeights(weights); - atomicSwapWeights(weights); + initWeights(weights); if (executorConfig_.layoutPlannerSettings.enabled()) { layoutPlanner_ = std::make_unique( @@ -113,13 +105,14 @@ void Executor::atomicSwapWeights(std::shared_ptr weights) { } } -void Executor::maybeRunConstantFolding(std::shared_ptr weights) { +void Executor::maybeRunConstantFolding( + const std::shared_ptr& weights) { for (auto& execution : constFoldingExecutions_) { ExecutionFrame constFoldingFrame(execution.executor->graph()); std::vector inputs; inputs.reserve(graph_->signature().inputsToWeights().size()); for (const auto& [_, name] : graph_->signature().inputsToWeights()) { - inputs.push_back(weights->at(name)); + inputs.emplace_back(weights->at(name)); } auto outputs = execution.executor->execute(constFoldingFrame, inputs); @@ -130,7 +123,7 @@ void Executor::maybeRunConstantFolding(std::shared_ptr weights) { } } -void Executor::processWeights(std::shared_ptr weights) { +void Executor::processWeights(const std::shared_ptr& weights) { maybeRunConstantFolding(weights); if (constantFolder_.has_value()) { constantFolder_->evaluate(*weights); @@ -140,20 +133,41 @@ void Executor::processWeights(std::shared_ptr weights) { } } +void Executor::initWeights(const std::shared_ptr& weights) { + maybeRunConstantFolding(weights); + if (constantFolder_.has_value()) { + constantFolder_->evaluate(*weights); + } + + weights_.withLock([&](auto& w) { w = std::move(weights); }); + + for (auto& delegateExecutor : delegateExecutors_) { + delegateExecutor->initWeights(weights); + } +} + namespace { void validateInput( const std::string& inputName, const at::Tensor& inputTensor, const torch::nativert::TensorMeta& tensorValueMeta) { - CHECK(inputTensor.dtype() == tensorValueMeta.dtype()) - << "Input tensor dtype mismatch for " << inputName << ", expecting " - << c10::toString(tensorValueMeta.dtype()) << " but got " - << inputTensor.dtype().name(); - - CHECK(inputTensor.device() == tensorValueMeta.device()) - << "Input tensor device mismatch for " << inputName << ", expecting " - << tensorValueMeta.device().str() << " but got " - << inputTensor.device().str(); + TORCH_CHECK( + inputTensor.dtype() == tensorValueMeta.dtype(), + "Input tensor dtype mismatch for ", + inputName, + ", expecting ", + c10::toString(tensorValueMeta.dtype()), + " but got ", + inputTensor.dtype().name()); + + TORCH_CHECK( + inputTensor.device() == tensorValueMeta.device(), + "Input tensor device mismatch for ", + inputName, + ", expecting ", + tensorValueMeta.device().str(), + " but got ", + inputTensor.device().str()); } } // namespace @@ -167,8 +181,11 @@ void Executor::validateInputs(const std::vector& inputs) const { if (actualInput.isTensor()) { const auto& inputName = std::string(inputValues[i]->name()); auto it = tensorValuesMeta.find(inputName); - CHECK(it != tensorValuesMeta.end()) - << "Couldn't find " << inputName << " in tensorValuesMeta"; + TORCH_CHECK( + it != tensorValuesMeta.end(), + "Couldn't find ", + inputName, + " in tensorValuesMeta"); validateInput(inputName, actualInput.toTensor(), it->second); } } @@ -289,15 +306,17 @@ void Executor::returnExecutorFrameToPool( // Create an entry with used=true if (C10_UNLIKELY(!clearingInProgress_)) { - CHECK(executionFrames_.writeIfNotFull(std::move(frame))) - << "ExecutionFrame pool full"; + TORCH_CHECK( + executionFrames_.writeIfNotFull(std::move(frame)), + "ExecutionFrame pool full"); } else { ExecutionFrameEntry frameEntry; frameEntry.used = true; frameEntry.frame = std::move(frame); - CHECK(clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry))) - << "Cleared ExecutionFrame pool full"; + TORCH_CHECK( + clearedExecutionFrames_.writeIfNotFull(std::move(frameEntry)), + "Cleared ExecutionFrame pool full"); } } catch (...) { sem_.release(); @@ -324,7 +343,7 @@ std::vector Executor::execute( std::optional> outputs; const auto userInputs = graph_->userInputs(); const auto& tensorValuesMeta = graph_->tensorValuesMeta(); - TORCH_CHECK_EQ(userInputs.size(), inputTreeSpec.numIValues()); + TORCH_CHECK(userInputs.size() == inputTreeSpec.numIValues()); auto executionFrameFillUserInputs = [&](const c10::IValue& leaf, const Value* value) { @@ -332,8 +351,11 @@ std::vector Executor::execute( if (executorConfig_.validateInputs && leaf.isTensor()) { const auto& inputName = std::string(value->name()); auto it = tensorValuesMeta.find(inputName); - CHECK(it != tensorValuesMeta.end()) - << "Couldn't find " << inputName << " in tensorValuesMeta"; + TORCH_CHECK( + it != tensorValuesMeta.end(), + "Couldn't find ", + inputName, + " in tensorValuesMeta"); validateInput(inputName, leaf.toTensor(), it->second); } executionFrame->setBorrowedIValue( @@ -352,11 +374,11 @@ std::vector Executor::execute( } ProfileMetrics Executor::benchmarkIndividualNodes( - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns) { - CHECK(inputsList.size() > 0) << "Need at least one input to benchmark"; - CHECK(warmupRuns >= 1 && mainRuns >= 1) << "Need at least one run"; + TORCH_CHECK(!inputsList.empty(), "Need at least one input to benchmark"); + TORCH_CHECK(warmupRuns >= 1 && mainRuns >= 1, "Need at least one run"); for (const auto& inputs : inputsList) { if (executorConfig_.validateInputs) { @@ -378,8 +400,9 @@ int64_t Executor::getCurrentTimestampSeconds() const { std::vector Executor::getDelegates() { std::vector delegates; + delegates.reserve(delegateExecutors_.size()); for (const auto& delegateExecutor : delegateExecutors_) { - delegates.push_back(delegateExecutor.get()); + delegates.emplace_back(delegateExecutor.get()); } return delegates; } diff --git a/torch/nativert/executor/Executor.h b/torch/nativert/executor/Executor.h index db496ace926e..57356c36d6c5 100644 --- a/torch/nativert/executor/Executor.h +++ b/torch/nativert/executor/Executor.h @@ -79,11 +79,10 @@ class Executor { Executor( torch::nativert::ExecutorConfig executorConfig, std::shared_ptr graph, - std::shared_ptr weights, - const Placement& placement = Placement(), - std::shared_ptr - pytorchStreamReader = nullptr, - const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + const std::shared_ptr& weights, + Placement placement = Placement(), + const std::shared_ptr& + pytorchStreamReader = nullptr); std::shared_ptr getWeights() { std::shared_ptr ret; @@ -91,7 +90,7 @@ class Executor { return ret; } - void processWeights(std::shared_ptr weights); + void processWeights(const std::shared_ptr& weights); void atomicSwapWeights(std::shared_ptr weights); // This API only returns the flattened UserOutputs, @@ -106,7 +105,7 @@ class Executor { const ITreeSpec& inputTreeSpec); ProfileMetrics benchmarkIndividualNodes( - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns); @@ -141,8 +140,8 @@ class Executor { c10::Synchronized> weights_; void initialize( - std::shared_ptr weights, - std::shared_ptr + const std::shared_ptr& weights, + const std::shared_ptr& pytorchStreamReader); ExecutorFramePtr getExecutorFrameFromPool(); @@ -171,12 +170,14 @@ class Executor { ExecutionFrameEntry& operator=(const ExecutionFrameEntry&) = delete; }; - void maybeRunConstantFolding(std::shared_ptr weights); + void maybeRunConstantFolding(const std::shared_ptr& weights); void validateInputs(const std::vector& inputs) const; // Helper method to get current timestamp in seconds int64_t getCurrentTimestampSeconds() const; + void initWeights(const std::shared_ptr& weights); + std::unique_ptr graphExecutor_; const Placement placement_; @@ -188,8 +189,6 @@ class Executor { std::optional constantFolder_; - MakeProxyExecutorFn makeProxyExecutorFunc_; - c10::Semaphore sem_; torch::nativert::detail::MPMCQueue> executionFrames_; diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index 5ad31a7dacab..9a527cc8117b 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -20,7 +20,7 @@ void GraphExecutorBase::fillUserInputs( std::vector inputs) { RECORD_USER_SCOPE("Executor::fillUserInputs"); const auto& inputValues = graph_.userInputs(); - TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + TORCH_CHECK(inputValues.size() == inputs.size()); // load user input tensor into execution frame for (size_t i = 0; i < inputValues.size(); i++) { @@ -32,7 +32,7 @@ void GraphExecutorBase::fillUserInputs( ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ExecutionFrame& executionFrame, - std::vector> inputsList, + const std::vector>& inputsList, const uint32_t warmupRuns, const uint32_t mainRuns) { // TODO: add support for memory profiling @@ -40,6 +40,13 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( ProfileMetrics results; const auto numNodes = static_cast(nodeKernels_.size()); + + results.percentPerNode.resize(numNodes, 0.0f); + results.nodeTypes.reserve(numNodes); + for (const auto& nodeKernel : nodeKernels_) { + results.nodeTypes.emplace_back(nodeKernel->node()->target()); + } + results.timePerNode.resize(numNodes, 0); if (inputsList.empty()) { auto i = 0; @@ -78,7 +85,7 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( for (auto inputs : inputsList) { const auto& inputValues = graph_.userInputs(); - TORCH_CHECK_EQ(inputValues.size(), inputs.size()); + TORCH_CHECK(inputValues.size() == inputs.size()); for (size_t j = 0; j < inputValues.size(); j++) { executionFrame.setIValue(inputValues[j]->id(), std::move(inputs[j])); } @@ -112,7 +119,11 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( results.totalNodesCount = numNodes; for (const auto& r : results.timePerNodeType) { const std::string& target = r.first; - results.percentPerNodeType[target] = r.second * 100.0 / results.totalTime; + results.percentPerNodeType[target] = r.second * 100.0f / results.totalTime; + } + for (const auto i : c10::irange(numNodes)) { + results.percentPerNode[i] = + results.timePerNode[i] * 100.0f / results.totalTime; } return results; } diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h index 86c6ed61c1f9..dfe020ebae29 100644 --- a/torch/nativert/executor/GraphExecutorBase.h +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -14,12 +14,15 @@ struct ProfileMetrics { size_t staticDispatchNodesCount{0}; size_t totalNodesCount{0}; std::vector timePerNode; + std::vector nodeTypes; std::unordered_map timePerNodeType; std::unordered_map percentPerNodeType; + std::vector percentPerNode; std::unordered_map instancesPerNodeType; std::unordered_set staticDispatchNodes; std::unordered_set primNodes; float totalTime{0}; + std::string name; }; /** @@ -51,7 +54,7 @@ class GraphExecutorBase { ProfileMetrics benchmarkIndividualNodes( ExecutionFrame& executionFrame, - std::vector> inputs, + const std::vector>& inputs, const uint32_t warmup_runs, const uint32_t main_runs); diff --git a/torch/nativert/executor/ParallelGraphExecutor.cpp b/torch/nativert/executor/ParallelGraphExecutor.cpp index c147d23873d3..b54b22228f97 100644 --- a/torch/nativert/executor/ParallelGraphExecutor.cpp +++ b/torch/nativert/executor/ParallelGraphExecutor.cpp @@ -22,11 +22,13 @@ ThreadPoolExecutor::~ThreadPoolExecutor() { } C10_ALWAYS_INLINE moodycamel::ProducerToken& ThreadPoolExecutor::ptok() { + // NOLINTNEXTLINE(misc-use-internal-linkage) thread_local moodycamel::ProducerToken ptok(*work_); return ptok; } C10_ALWAYS_INLINE moodycamel::ConsumerToken& ThreadPoolExecutor::ctok() { + // NOLINTNEXTLINE(misc-use-internal-linkage) thread_local moodycamel::ConsumerToken ctok(*work_); return ctok; } @@ -39,7 +41,7 @@ void ThreadPoolExecutor::execute_inline(SessionState* session, WorkUnit* unit) { void ThreadPoolExecutor::start(int32_t numThreads) { stopped_ = false; for (int32_t i = 0; i < numThreads; ++i) { - threads_.emplace_back(std::thread(&ThreadPoolExecutor::loop, this)); + threads_.emplace_back(&ThreadPoolExecutor::loop, this); } } @@ -62,16 +64,17 @@ void ThreadPoolExecutor::loop() { void ThreadPoolExecutor::add(SessionState* session, WorkUnit* unit) { session->addWork(); - work_->enqueue(ptok(), std::bind(&WorkUnit::run, unit, this, session)); + work_->enqueue(ptok(), [unit, this, session] { unit->run(this, session); }); sem_->release(); } void ThreadPoolExecutor::add( SessionState* session, - std::vector::const_iterator&& begin, - const std::vector::const_iterator&& end) { + std::vector::const_iterator begin, + const std::vector::const_iterator& end) { const auto count = end - begin; + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) switch (count) { case 0: { return; @@ -86,16 +89,17 @@ void ThreadPoolExecutor::add( std::vector runnables; runnables.reserve(count); for (; begin != end; ++begin) { - runnables.push_back(std::bind(&WorkUnit::run, *begin, this, session)); + runnables.emplace_back( + [capture0 = *begin, this, session] { capture0->run(this, session); }); } work_->enqueue_bulk(ptok(), runnables.begin(), count); - sem_->release(count); + sem_->release(static_cast(count)); } void ThreadPoolExecutor::stop() { stopped_ = true; - sem_->release(threads_.size()); + sem_->release(static_cast(threads_.size())); std::for_each(threads_.begin(), threads_.end(), [](auto& t) { t.join(); }); threads_.clear(); @@ -136,10 +140,10 @@ void ThreadPoolExecutor::run( } void WorkUnit::run(ThreadPoolExecutor* executor, SessionState* session) { - thread_local std::vector newWorkUnits; - thread_local c10::InferenceMode mode; + /* thread_local */ std::vector newWorkUnits; + /* thread_local */ c10::InferenceMode mode; - WorkUnit* unit = this; + /* thread_local */ WorkUnit* unit = this; while (true) { unit->kernel->compute(session->frame()); @@ -219,7 +223,7 @@ ParallelGraphExecutor::ParallelGraphExecutor( } } - executor_.start(executorConfig.maxParallelOps); + executor_.start(static_cast(executorConfig.maxParallelOps)); } std::vector ParallelGraphExecutor::execute( diff --git a/torch/nativert/executor/ParallelGraphExecutor.h b/torch/nativert/executor/ParallelGraphExecutor.h index 747e6993770a..1810ffb3b7b1 100644 --- a/torch/nativert/executor/ParallelGraphExecutor.h +++ b/torch/nativert/executor/ParallelGraphExecutor.h @@ -46,8 +46,8 @@ class ThreadPoolExecutor { void add(SessionState* session, WorkUnit* unit); void add( SessionState* session, - std::vector::const_iterator&& begin, - const std::vector::const_iterator&& end); + std::vector::const_iterator begin, + const std::vector::const_iterator& end); C10_ALWAYS_INLINE moodycamel::ProducerToken& ptok(); C10_ALWAYS_INLINE moodycamel::ConsumerToken& ctok(); diff --git a/torch/nativert/executor/Placement.cpp b/torch/nativert/executor/Placement.cpp index be8b6e6df966..0432ecdc2a7c 100644 --- a/torch/nativert/executor/Placement.cpp +++ b/torch/nativert/executor/Placement.cpp @@ -32,6 +32,15 @@ std::ostream& operator<<(std::ostream& os, const Placement& placement) { return os; } +namespace { +void assertCudaDeviceHasIndex(const c10::Device& device) { + if (device.is_cuda()) { + TORCH_CHECK( + device.has_index(), "CUDA device in placement must have an index"); + } +} +} // namespace + Placement::Placement(std::optional defaultDevice) : Placement({}, defaultDevice) {} @@ -39,16 +48,20 @@ Placement::Placement( const std::unordered_map& deviceMap, std::optional defaultDevice) { for (const auto& [srcDevice, dstDevice] : deviceMap) { - deviceMap_.try_emplace( - normalizeDevice(srcDevice), normalizeDevice(dstDevice)); + assertCudaDeviceHasIndex(srcDevice); + assertCudaDeviceHasIndex(dstDevice); + + deviceMap_.try_emplace(srcDevice, dstDevice); } + if (defaultDevice.has_value()) { - defaultDevice_ = normalizeDevice(defaultDevice.value()); + assertCudaDeviceHasIndex(defaultDevice.value()); + defaultDevice_ = defaultDevice.value(); } } c10::Device Placement::getMappedDevice(const c10::Device& srcDevice) const { - auto it = deviceMap_.find(normalizeDevice(srcDevice)); + auto it = deviceMap_.find(srcDevice); if (it != deviceMap_.end()) { return it->second; } diff --git a/torch/nativert/executor/Placement.h b/torch/nativert/executor/Placement.h index 9f9a2c627d25..6ea86348973e 100644 --- a/torch/nativert/executor/Placement.h +++ b/torch/nativert/executor/Placement.h @@ -8,21 +8,6 @@ namespace torch::nativert { -/** - * This function returns a normalized version of the input device: - * - For CPU devices, the returned device will have no index (i.e., the default - * CPU device). - * - For CUDA devices, if no index is specified, index 0 is assumed. - * - For other device types, the function will raise an error. - * - * @param device The input c10::Device to normalize. - * @return A normalized c10::Device with standardized indexing. - * - * @throws c10::Error If the device type is not CPU or CUDA. - */ - -c10::Device normalizeDevice(const c10::Device& device); - /** * Returns true if the two devices are the same and has the same device index * (if cuda). diff --git a/torch/nativert/executor/PlacementUtils.cpp b/torch/nativert/executor/PlacementUtils.cpp index 988c9997ed03..e73224b4f4f5 100644 --- a/torch/nativert/executor/PlacementUtils.cpp +++ b/torch/nativert/executor/PlacementUtils.cpp @@ -4,20 +4,6 @@ namespace torch::nativert { -c10::Device normalizeDevice(const c10::Device& device) { - // cpu device doesn't have index - // cuda device index must have a index - if (device.is_cpu()) { - return c10::Device(c10::DeviceType::CPU); - } else if (device.is_cuda()) { - return c10::Device( - c10::DeviceType::CUDA, - device.has_index() ? device.index() : static_cast(0)); - } else { - TORCH_CHECK(false, "Unsupported device type", device); - } -} - bool isSameDevice(const c10::Device& a, const c10::Device& b) { if (a.is_cpu()) { return b.is_cpu(); diff --git a/torch/nativert/executor/SerialGraphExecutor.cpp b/torch/nativert/executor/SerialGraphExecutor.cpp index 017f4f178c8b..58a7cd1c4307 100644 --- a/torch/nativert/executor/SerialGraphExecutor.cpp +++ b/torch/nativert/executor/SerialGraphExecutor.cpp @@ -14,11 +14,17 @@ std::vector SerialGraphExecutor::execute( std::vector SerialGraphExecutor::executeWithPrefilledFrame( ExecutionFrame& executionFrame) { - executionFrame.withMemoryPlanner([&]() { + executionFrame.withManagedMemory([&](const LayoutManager* layout_manager) { // Execute kernels for all nodes except prim.Input and prim.Output for (NodeIndex nodeIdx = 1; nodeIdx < nodeKernels_.size() - 1; ++nodeIdx) { nodeKernels_[nodeIdx]->compute(executionFrame); +#ifndef NDEBUG + if (layout_manager != nullptr) { + layout_manager->assert_no_overlapping_storages(nodeIdx); + } +#endif + // don't free intermediate values when static memory planning is enabled if (executorConfig_.tryFreeUnmanagedValuesAfterUse) { // Free the intermediate values that are no used anymore diff --git a/torch/nativert/executor/memory/AliasAnalyzer.cpp b/torch/nativert/executor/memory/AliasAnalyzer.cpp index e56eb4085316..86de7bc3d6fb 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.cpp +++ b/torch/nativert/executor/memory/AliasAnalyzer.cpp @@ -23,18 +23,32 @@ AliasAnalyzer::AliasAnalyzer( maybe_update_aliases_from_schema(node, schemas); } + maybe_extend_lifetimes(graph); + + // squash_deep_aliases this will populate aliases_ + // with a mapping from each alias to its backed + // source (i.e., the value that owns the underlying + // dataptr for said alias) + squash_deep_aliases(graph); + // set all non-aliasing outputs. outputs // that are aliased will be set later when // lifetimes are extended for (const auto* output : graph.outputs()) { if (!is_alias(output)) { - values_associated_with_outputs_.insert(output); + values_associated_with_outputs_.emplace(output); } } - maybe_extend_lifetimes(graph); log_state(); -} + + alive_values_at_time_.resize(graph.nodes().size()); + for (const auto& [v, lifetime] : lifetimes_) { + for (const auto t : c10::irange(lifetime.start, lifetime.end + 1)) { + alive_values_at_time_[t].emplace_back(v); + } + } +} // namespace torch::nativert bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( const Node& node, @@ -52,18 +66,18 @@ bool /* applied */ AliasAnalyzer::update_aliases_if_packed_listunpack( } const auto& list_elems = list->getListElements(); - TORCH_CHECK_EQ(list_elems.size(), node.numOutputs()); + TORCH_CHECK(list_elems.size() == node.numOutputs()); for (const auto j : c10::irange(node.numOutputs())) { const Value* input = list_elems.at(j); const Value* output = node.outputs().at(j); - TORCH_CHECK_NE(input, output); + TORCH_CHECK(input != output); create_or_update_lifetime(input, i); create_or_update_lifetime(output, i); - aliases_[output].insert(input); + aliases_[output].emplace(input); } return true; @@ -96,7 +110,7 @@ void AliasAnalyzer::maybe_update_aliases_from_schema( VLOG(1) << node.target() << " may contain input/output alias: " << input->id() << " -> " << output->id(); - aliases_[output].insert(input); + aliases_[output].emplace(input); } } } @@ -109,6 +123,56 @@ void AliasAnalyzer::create_or_update_lifetime(const Value* value, size_t i) { } } +void AliasAnalyzer::squash_deep_aliases(const Graph& graph) { + for (auto& node : graph.nodes()) { + for (const auto& output : node.outputs()) { + auto aliasIt = aliases_.find(output); + if (aliasIt == aliases_.end()) { + continue; + } + + c10::FastSet filtered_srcs; + + auto& srcs = aliasIt->second; + for (const auto* src : srcs) { + // check if this source is an alias itself, + // making 'output' a deep alias (i.e., + // an alias of an alias) + + // we want aliases_[x] to return the value from which x + // inherits its dataptr. + // as such, we want to add values that do not meet this + // criteria (i.e., those that are aliases). + // in practice, there can only be 1 value that meets this + // criteria (at a time), but there are some cases where + // this is ambiguous (e.g., where the spec doesn't exist, + // dealing with variadics) + auto srcAliasIt = aliases_.find(src); + if (srcAliasIt == aliases_.end()) { + filtered_srcs.emplace(src); + continue; + } + + // since we are going from the beginning of the graph + // to the end of the graph we can assume that these + // aliases, which have already been visited, have already + // been squashed. + auto& srcs_of_src = srcAliasIt->second; + for (const auto* src_of_src : srcs_of_src) { + // if the source of the source is not an alias + // (i.e., it has ownership over it's data ptr) + // then we want to add it as a source of 'output' + if (aliases_.find(src_of_src) == aliases_.end()) { + filtered_srcs.emplace(src_of_src); + } + } + } + + srcs = std::move(filtered_srcs); + } + } +} + void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { c10::FastSet extended; @@ -129,10 +193,11 @@ void AliasAnalyzer::maybe_extend_lifetimes(const Graph& graph) { VLOG(1) << "extended EOL of value " << src->id() << " to " << eol; - extended.insert(src); + extended.emplace(src); - if (eol == graph.nodes().size() - 1 /* aliases output */) { - values_associated_with_outputs_.insert(src); + if (aliases_.find(src) == aliases_.end() && + eol == graph.nodes().size() - 1 /* aliases output */) { + values_associated_with_outputs_.emplace(src); } } } diff --git a/torch/nativert/executor/memory/AliasAnalyzer.h b/torch/nativert/executor/memory/AliasAnalyzer.h index c9784d5d84ab..4b0d827453b0 100644 --- a/torch/nativert/executor/memory/AliasAnalyzer.h +++ b/torch/nativert/executor/memory/AliasAnalyzer.h @@ -14,26 +14,38 @@ class AliasAnalyzer { const Graph& graph, const c10::FastMap& schemas); - C10_ALWAYS_INLINE const AllocationLifetime& lifetime( + const c10::FastSet* get_sources_of_alias( const Value* value) const { + const auto it = aliases_.find(value); + if (it == aliases_.end()) { + return nullptr; + } + return &it->second; + } + + const AllocationLifetime& lifetime(const Value* value) const { return lifetimes_.at(value); } - C10_ALWAYS_INLINE bool is_alias(const Value* value) const { + bool is_alias(const Value* value) const { return aliases_.find(value) != aliases_.end(); } - C10_ALWAYS_INLINE bool is_storage_associated_with_output( - const Value* value) const { + bool is_storage_associated_with_output(const Value* value) const { return values_associated_with_outputs_.find(value) != values_associated_with_outputs_.end(); } - C10_ALWAYS_INLINE const c10::FastSet& - values_associated_with_output_storage() const { + const c10::FastSet& values_associated_with_output_storage() + const { return values_associated_with_outputs_; } + const std::vector& alive_values_at_time(size_t time) const { + TORCH_CHECK(time < alive_values_at_time_.size()); + return alive_values_at_time_[time]; + } + private: // listunpack operations who take a list that has // been created with a listpack operation should @@ -72,14 +84,35 @@ class AliasAnalyzer { // even if they aren't explicitly considered outputs) void maybe_extend_lifetimes(const Graph& graph); + // in the event that we have aliases-of-aliases + // we want to make sure that the 'sources' + // are propagated + // + // e.g., + // %x0 = ... + // %x1 = some_aliasing_op(x0) + // %x2 = some_aliasing_op(x1) + // + // we want aliases_[x2] = x0 + // instead of aliases[x2] = x1 + // + // the result is aliases_ will contain a + // mapping from each alias to its backed + // source (i.e., the value that owns its + // associated dataptr) + void squash_deep_aliases(const Graph& graph); + void log_state() const; - // mapping from alias to the set of values that it aliases + // mapping from alias to its source c10::FastMap> aliases_; c10::FastMap lifetimes_; // non-aliasing outputs or non-aliasing intermediates that are aliased by // outputs c10::FastSet values_associated_with_outputs_; + // alive_values_at_time_[i] = values that are "alive" during the + // computation of node i + std::vector> alive_values_at_time_; }; } // namespace torch::nativert diff --git a/torch/nativert/executor/memory/LayoutManager.cpp b/torch/nativert/executor/memory/LayoutManager.cpp index 7b5062d7993f..827e8cd05781 100644 --- a/torch/nativert/executor/memory/LayoutManager.cpp +++ b/torch/nativert/executor/memory/LayoutManager.cpp @@ -4,6 +4,7 @@ #include #include +#include namespace torch::nativert { @@ -139,14 +140,17 @@ void LayoutManager::ensure_managed_storages(bool allocate) { } void LayoutManager::populate_tensor_values() { - CHECK(planned_tensors_.empty()); - CHECK(unplanned_ivalues_.empty()); + TORCH_CHECK(planned_tensors_.empty()); + TORCH_CHECK(unplanned_ivalues_.empty()); const auto& value_ids = planner_.get_planned_values(); planned_tensors_.resize(value_ids.size()); planned_tensors_max_nbytes_local_.resize(value_ids.size()); for (const auto&& [i, v] : c10::enumerate(value_ids)) { +#ifndef NDEBUG + value_to_vector_idx_map_[v] = i; +#endif planned_tensors_[i] = &parent_frame_.getIValue(v).toTensor(); } @@ -157,6 +161,165 @@ void LayoutManager::populate_tensor_values() { } } +#ifndef NDEBUG +void LayoutManager::assert_no_overlapping_storages( + size_t graph_node_idx) const { + if (state_ != LayoutManagerState::Running) { + return; + } + + /* + for each value + (either an input or output) + ensure that the associated storage + slice lies within the allocated slice + if it is managed (or if it is an alias, + we can use the slice allocated to its source) + --- + also ensure that the current index lies + within the lifetime of this value + */ + + const auto& alias_analyzer = planner_.get_alias_analyzer(); + // get the 'active' values during the execution of nodes[graph_node_idx] + const auto& alive_values = + alias_analyzer.alive_values_at_time(graph_node_idx); + + // make sure active memory intervals are non-overlapping + // by sorting them by start, and ensuring + // cur.start > prev.end for each + // + // by default, the pairs are compared lexicographically. + // ref: https://cplusplus.com/reference/utility/pair/operators/ + // + // in our case, this means that leftmost (on the number line) intervals will + // come first, and if the start point of two intervals is the same, they will + // be sorted by their relative widths (in increasing order) + // + // e.g., the ordering for the following usage intervals + // + // |######1######| + // |######2######| + // |######3#####| + // + // would be [1,3,2] + + std::multiset> intervals; + + planner_.with_plan([&](const LayoutPlan& plan) { + // prevent recomputation from occurring + c10::FastSet checked_values; + + // check that some arbitrary storage (defined by the allocation start and + // the size in bytes) lies within the slice allocated for value_id during + // planning. + // + // if the checks pass, add the interval [alloc_start, alloc_start + + // alloc_nbytes) to the set of intervals + auto check_allocation_bounds = + [&](ValueId value_id, size_t alloc_start, size_t alloc_end) -> void { + if (!checked_values.emplace(value_id).second /* already checked */) { + return; + } + auto& alloc = plan.allocations[value_to_vector_idx_map_.at(value_id)]; + TORCH_CHECK(alloc_start >= alloc.offset); + TORCH_CHECK(alloc_end < alloc.offset + alloc.size); + intervals.emplace(alloc_start, alloc_end); + }; + + // get the inclusive storage interval for some value (i.e., + // [buffer_storage_start_offset, buffer_storage_start_offset + + // storage_nbytes]) that represents the sub-slice of the runtime-managed + // buffer allocated to this tensor + auto try_get_interval = + [&](ValueId value_id) -> std::optional> { + const auto& iv = parent_frame_.getIValue(value_id); + if (!iv.isTensor()) { + return std::nullopt; + } + + const auto& storage_impl = iv.toTensor().storage().unsafeGetStorageImpl(); + const auto storage_nbytes = storage_impl->nbytes(); + + if (const auto start = layout_buffer_.get_offset_from_ptr( + storage_impl->data_ptr().get()); + start.has_value()) { + return std::make_pair(*start, *start + storage_nbytes - 1); + } + + return std::nullopt; + }; + + for (auto v : alive_values) { + // sanity check lifetimes to ensure this + // value ~should~ be alive at this point + const auto& lt = alias_analyzer.lifetime(v); + TORCH_CHECK(graph_node_idx >= lt.start); + TORCH_CHECK(graph_node_idx <= lt.end); + + const auto interval = try_get_interval(v->id()); + if (C10_UNLIKELY(!interval.has_value())) { + continue; + } + + auto& [v_start, v_end] = *interval; + + // it's possible that v is an alias, in which case + // we want to try to get the source (i.e., the value) + // that actually owns the storage + // + // NOTE: it's possible the source is ambiguous, hence + // why get_sources_of_alias returns a set (although it's usually a + // singleton set) + if (const auto* srcs_of_v = alias_analyzer.get_sources_of_alias(v); + srcs_of_v != nullptr /* v is an alias */) { + // 1. v's interval is a sub-interval of ~a~ source's interval and we + // want to add the source's interval to the set of intervals + // 2. v possibly got re-alloc'd / is not actually aliasing anything + // and we want to add v's interval to the set of intervals + bool found_viable_source = false; + + for (const auto* src_of_v : *srcs_of_v) { + const auto src_interval = try_get_interval(src_of_v->id()); + if (C10_UNLIKELY(!src_interval.has_value())) { + continue; + } + + auto& [src_of_v_start, src_of_v_end] = *src_interval; + + if (v_start >= src_of_v_start && v_end <= src_of_v_end) { + check_allocation_bounds( + src_of_v->id(), src_of_v_start, src_of_v_end); + found_viable_source = true; + break; + } + } + + if (!found_viable_source) { + check_allocation_bounds(v->id(), v_start, v_end); + } + } else /* if v isn't an alias */ { + check_allocation_bounds(v->id(), v_start, v_end); + } + } + }); + + // if we only have less than two active intervals, + // it isn't possible to have overlap... + if (intervals.size() < 2) { + return; + } + + // ensure that no 'active' buffer intervals are overlapping + auto it = intervals.begin(); + size_t prev_end = it->second; + while (++it != intervals.end()) { + TORCH_CHECK(prev_end < it->first /* cur_start */); + prev_end = it->second; + } +} +#endif + void LayoutManager::try_update_historical_max_nbytes() { for (const auto i : c10::irange(planned_tensors_.size())) { auto nbytes = get_aligned_nbytes(planned_tensors_[i]->nbytes()); diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h index 76f658e09d08..d98700e7f021 100644 --- a/torch/nativert/executor/memory/LayoutManager.h +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -24,10 +24,24 @@ struct ContiguousLayoutBuffer { ContiguousLayoutBuffer& operator=(const ContiguousLayoutBuffer& other) = delete; + std::optional get_offset_from_ptr(void* offset_ptr) const { + void* raw_ptr = data_ptr_.get(); + if (!raw_ptr || !offset_ptr) { + return std::nullopt; + } + + auto offset = reinterpret_cast(offset_ptr) - + reinterpret_cast(raw_ptr); + + return offset < 0 || static_cast(offset) >= size_ + ? std::nullopt + : std::optional(offset); + } + void* get_ptr_with_offset(size_t offset) { void* raw_ptr = data_ptr_.get(); - TORCH_CHECK_NOTNULL(raw_ptr); - TORCH_CHECK_LE(offset, size_); + TORCH_CHECK(raw_ptr != nullptr); + TORCH_CHECK(offset <= size_); return reinterpret_cast( reinterpret_cast(raw_ptr) + offset); } @@ -47,7 +61,7 @@ struct ContiguousLayoutBuffer { void clear(size_t size) { VLOG(1) << "clearing first " << size << "bytes of layout buffer of size " << size_; - TORCH_CHECK_LE(size, size_); + TORCH_CHECK(size <= size_); std::memset(data_ptr_.get(), 0, size); } @@ -112,8 +126,8 @@ struct ContiguousStorageImplBuffer { } c10::StorageImpl& at(size_t i) { - TORCH_CHECK_LT(i, size_) - << "requested storage index " << i << " out of bounds " << size_; + TORCH_CHECK( + i < size_, "requested storage index ", i, " out of bounds ", size_); return buffer_[i]; } @@ -124,7 +138,7 @@ struct ContiguousStorageImplBuffer { } c10::StorageImpl& to_managed(at::StorageImpl& s) { - TORCH_CHECK_LT(size_, capacity_); + TORCH_CHECK(size_ < capacity_); return *(new (&buffer_[size_++]) at::StorageImpl( at::StorageImpl::use_byte_size_t(), static_cast(s.nbytes()), @@ -148,10 +162,32 @@ class LayoutManager { torch::nativert::LayoutManagerSettings settings = {}); ~LayoutManager() = default; +// this is a debugging function. it will slow thing down SIGNIFICANTLY +// so please ensure this isn't called unless you really need it +// +// it checks a few things in between node executions... +// +// 1. ensures all 'alive' values are within the bounds of their lifetimes +// - this is the definition of a sanity check since the live-sets are built +// from the lifetimes lol. if this fails, something is very very wrong +// 2. ensures that all planned values are within the bounds of their +// allocated storage buffer slices +// - if the value is an alias, ensure the alias is within the bounds +// of the source value +// 3. ensures that all planned value data-ptrs are non-overlapping +#ifndef NDEBUG + void assert_no_overlapping_storages( + size_t + graph_node_idx /* the graph node that is currently being computed */) + const; +#endif + + private: + friend class LayoutManagerGuard; + void allocate(); void deallocate_and_plan(); - private: #ifdef LayoutPlannerTests_TEST_FRIENDS LayoutPlannerTests_TEST_FRIENDS; #endif @@ -178,6 +214,9 @@ class LayoutManager { std::vector planned_tensors_; std::vector planned_tensors_max_nbytes_local_; +#ifndef NDEBUG + c10::FastMap value_to_vector_idx_map_; +#endif ContiguousLayoutBuffer layout_buffer_; ContiguousStorageImplBuffer storage_impl_buffer_; diff --git a/torch/nativert/executor/memory/LayoutPlanner.cpp b/torch/nativert/executor/memory/LayoutPlanner.cpp index 5c45a08ea6f1..ead887bbe470 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.cpp +++ b/torch/nativert/executor/memory/LayoutPlanner.cpp @@ -16,9 +16,18 @@ LayoutPlanner::LayoutPlanner( const c10::FastMap& kernelSchemas, const std::vector& persistentValues, const torch::nativert::LayoutPlannerSettings& settings) - : managed_values_(graph.values().size()), settings_(settings) { - auto value_to_allocation_spec = c10::FastMap{}; + : managed_values_(graph.values().size()), +#ifndef NDEBUG + alias_analyzer_(graph, kernelSchemas), +#endif + settings_(settings) { +#ifndef NDEBUG + auto& alias_analyzer = alias_analyzer_; +#else auto alias_analyzer = AliasAnalyzer(graph, kernelSchemas); +#endif + + auto value_to_allocation_spec = c10::FastMap{}; std::set input_values_set_; for (const auto* nv : graph.userInputs()) { @@ -124,7 +133,7 @@ LayoutPlanner::LayoutPlanner( } } - TORCH_CHECK_NOTNULL(algorithm_); + TORCH_CHECK(algorithm_ != nullptr, "algorithm can't be null"); initialize_vectors(value_to_allocation_spec); @@ -150,7 +159,9 @@ void LayoutPlanner::initialize_vectors( size_t i = 0; for (auto& [v, spec] : value_to_allocation_spec) { - TORCH_CHECK_LE(spec.lifetime.start, spec.lifetime.end); + TORCH_CHECK( + spec.lifetime.start <= spec.lifetime.end, + "lifetime start must be before lifetime end"); planned_values_[i] = v->id(); planned_values_historical_max_nbytes_[i] = spec.size; diff --git a/torch/nativert/executor/memory/LayoutPlanner.h b/torch/nativert/executor/memory/LayoutPlanner.h index 6382fdbba01b..10dcf906bef3 100644 --- a/torch/nativert/executor/memory/LayoutPlanner.h +++ b/torch/nativert/executor/memory/LayoutPlanner.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -61,8 +62,18 @@ class LayoutPlanner { const std::vector& get_planned_values() const; const std::vector& get_unplanned_values() const; - C10_ALWAYS_INLINE bool is_managed(ValueId id) { - TORCH_CHECK_LT(static_cast(id), managed_values_.size()); +#ifndef NDEBUG + const AliasAnalyzer& get_alias_analyzer() const { + return alias_analyzer_; + } +#endif + + size_t num_values() const { + return managed_values_.size(); + } + + bool is_managed(ValueId id) { + TORCH_CHECK(static_cast(id) < managed_values_.size()); return managed_values_[id]; } @@ -120,6 +131,9 @@ class LayoutPlanner { LayoutPlannerAlgorithm* algorithm_; c10::LeftRight plan_; +#ifndef NDEBUG + AliasAnalyzer alias_analyzer_; +#endif torch::nativert::LayoutPlannerSettings settings_; }; diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index 3cc7f678fcff..bce01f278a57 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -568,7 +568,7 @@ void Graph::lint() const { } } for (const auto& node : nodes()) { - TORCH_CHECK_EQ(node.owningGraph(), this); + TORCH_CHECK(node.owningGraph() == this); } // Check that every list type is either produced by a prim.ListPack or // immediately consumed by a prim.ListUnpack. We make use of this invariant @@ -668,7 +668,7 @@ void Graph::applyDevicePlacement(const Placement& placement) { } Node* Graph::nodeAfter(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } @@ -677,7 +677,7 @@ Node* Graph::nodeAfter(Node* n) { } const Node* Graph::nodeAfter(const Node* n) const { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == outputNode_) { return nullptr; } @@ -686,7 +686,7 @@ const Node* Graph::nodeAfter(const Node* n) const { } Node* Graph::nodeBefore(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } @@ -695,7 +695,7 @@ Node* Graph::nodeBefore(Node* n) { } const Node* Graph::nodeBefore(const Node* n) const { - TORCH_CHECK_EQ(n->owningGraph(), this); + TORCH_CHECK(n->owningGraph() == this); if (n == inputNode_) { return nullptr; } @@ -704,8 +704,7 @@ const Node* Graph::nodeBefore(const Node* n) const { } void Graph::removeNode(Node* n) { - TORCH_CHECK_EQ(n->owningGraph(), this) - << "Node does not belong to this graph!"; + TORCH_CHECK(n->owningGraph() == this, "Node does not belong to this graph!"); for (auto* outputVal : n->outputs()) { TORCH_CHECK( @@ -747,8 +746,7 @@ std::vector Graph::insertGraph( const Graph& subgraph, std::vector inputs, std::unordered_map& valueMap) { - TORCH_CHECK_EQ(subgraph.inputs().size(), inputs.size()) - << "Input size mismatch"; + TORCH_CHECK(subgraph.inputs().size() == inputs.size(), "Input size mismatch"); for (auto i : c10::irange(subgraph.inputs().size())) { valueMap[subgraph.inputs()[i]] = inputs[i]; } @@ -854,7 +852,7 @@ void Node::addOutput() { } Value* Node::addOutput(const Type& type) { - TORCH_CHECK_EQ(type, Type::Kind::None); + TORCH_CHECK(type == Type::Kind::None); Value* v = owningGraph_->addValue(std::nullopt, type, this); outputs_.push_back(v); return v; @@ -893,9 +891,9 @@ std::vector Value::getListElements() const { ret.push_back(tv.value); } } else { - TORCH_CHECK_EQ(users().size(), 1); + TORCH_CHECK(users().size() == 1); const auto listUnpack = users()[0]; - TORCH_CHECK_EQ(listUnpack->target(), "prim.ListUnpack"); + TORCH_CHECK(listUnpack->target() == "prim.ListUnpack"); for (const auto v : listUnpack->outputs()) { ret.push_back(v); } @@ -1070,17 +1068,17 @@ std::ostream& operator<<(std::ostream& out, const Graph& graph) { c10::Device convertDevice(std::string_view symbol) { // Symbol looks like `Device{cuda:1}` const auto typeStart = symbol.find('{') + 1; - TORCH_CHECK_LT(typeStart, symbol.size()); + TORCH_CHECK(typeStart < symbol.size()); const auto typeEnd = symbol.find(':'); - TORCH_CHECK_NE(typeEnd, std::string_view::npos); + TORCH_CHECK(typeEnd != std::string_view::npos); const auto type = symbol.substr(typeStart, typeEnd - typeStart); const auto indexStart = typeEnd + 1; - TORCH_CHECK_LT(indexStart, symbol.size()); + TORCH_CHECK(indexStart < symbol.size()); const auto indexEnd = symbol.find('}'); - TORCH_CHECK_NE(indexEnd, std::string_view::npos); + TORCH_CHECK(indexEnd != std::string_view::npos); const auto index = symbol.substr(indexStart, indexEnd - indexStart); @@ -1099,7 +1097,7 @@ c10::Device convertDevice(std::string_view symbol) { Constant convertAtomicConstant(std::string_view symbol) { if (c10::starts_with(symbol, "\"")) { // chop off the outer quotes and return the string - TORCH_CHECK_GE(symbol.size(), 2); + TORCH_CHECK(symbol.size() >= 2); symbol.remove_prefix(1); symbol.remove_suffix(1); return std::string(symbol); @@ -1178,8 +1176,8 @@ Constant convertListConstant(std::string_view source) { TORCH_CHECK(false, "constant lists only support int, float, bool"); } } else { - TORCH_CHECK_EQ(type.index(), val.index()) - << "lists must have all the same type"; + TORCH_CHECK( + type.index() == val.index(), "lists must have all the same type"); } values.push_back(std::move(val)); if (source.at(curPos) == ']') { @@ -1306,7 +1304,7 @@ bool Parser::nextIf(char expected) { } void Parser::parseGraphInputs() { - TORCH_CHECK_EQ(curPos_, 0); + TORCH_CHECK(curPos_ == 0); expect("graph"); const auto inputs = parseList( '(', ')', [&]() { return parseAtomicSymbol(); }); diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp index 327f32185e91..981a63815db2 100644 --- a/torch/nativert/graph/GraphPasses.cpp +++ b/torch/nativert/graph/GraphPasses.cpp @@ -101,7 +101,7 @@ std::string selectScalarOverloadName(const Node& node) { "floor_divide_out", "_conj"}; std::vector atoms = c10::split(node.target(), '.'); - TORCH_CHECK_GE(atoms.size(), 3); + TORCH_CHECK(atoms.size() >= 3); std::string ns = std::string{atoms[atoms.size() - 3]}; std::string opName = std::string{atoms[atoms.size() - 2]}; diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index d32e7fe72843..4c45edd1f575 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -422,9 +422,11 @@ std::unique_ptr jsonToSubgraph( } auto it = jsonTensorValue.find(inputName); - CHECK(it != jsonTensorValue.end()) - << "Missing tensor metadata for " << inputName - << "in thriftGraph.tensorValue"; + TORCH_CHECK( + it != jsonTensorValue.end(), + "Missing tensor metadata for ", + inputName, + "in thriftGraph.tensorValue"); weightsTensorMeta[weightName] = it->second; } graph->setWeightsMeta(weightsTensorMeta); diff --git a/torch/nativert/graph/TensorMeta.cpp b/torch/nativert/graph/TensorMeta.cpp index d7d83710a5a3..97afbc9f095e 100644 --- a/torch/nativert/graph/TensorMeta.cpp +++ b/torch/nativert/graph/TensorMeta.cpp @@ -41,6 +41,10 @@ c10::ScalarType convertJsonScalarType( return c10::ScalarType::Float8_e4m3fn; case torch::_export::ScalarType::FLOAT8E5M2: return c10::ScalarType::Float8_e5m2; + case torch::_export::ScalarType::FLOAT8E4M3FNUZ: + return c10::ScalarType::Float8_e4m3fnuz; + case torch::_export::ScalarType::FLOAT8E5M2FNUZ: + return c10::ScalarType::Float8_e5m2fnuz; default: TORCH_CHECK(false, "unknown scalar type", static_cast(scalarType)); } @@ -106,7 +110,7 @@ TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta) torch::_export::SymInt::Tag::AS_INT) { storage_offset_ = tensorMeta.get_storage_offset().get_as_int(); } else { - CHECK(false) << "SymInt not supported yet"; + TORCH_CHECK(false, "SymInt not supported yet"); } for (const auto& size : tensorMeta.get_sizes()) { diff --git a/torch/nativert/graph/TensorMeta.h b/torch/nativert/graph/TensorMeta.h index 5b0c90474a09..585383a95b5f 100644 --- a/torch/nativert/graph/TensorMeta.h +++ b/torch/nativert/graph/TensorMeta.h @@ -25,12 +25,12 @@ class TensorMeta { explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta); c10::IntArrayRef sizes() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return sizes_; } c10::IntArrayRef strides() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return strides_; } @@ -55,7 +55,7 @@ class TensorMeta { } int64_t numel() const { - CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape"; + TORCH_CHECK(!hasSymbolicShape_, "TensorMeta has symbolic shape"); return numel_; } diff --git a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp index cbbd502d8215..76589b52c56e 100644 --- a/torch/nativert/kernels/AutoFunctionalizeKernel.cpp +++ b/torch/nativert/kernels/AutoFunctionalizeKernel.cpp @@ -11,15 +11,14 @@ UnsafeAutoFunctionalizeKernel::UnsafeAutoFunctionalizeKernel(const Node* node) op_(getOperatorForTarget( std::get(node->attributes()[0].value))), schema_(op_.schema()), - arguments_(prefillStackWithStaticArgs(node, schema_)) { + arguments_(prefillStackWithStaticArgs(node, schema_)), + numOutputs_(static_cast(schema_.returns().size())) { for (const auto& [idx, schemaArg] : c10::enumerate(schema_.arguments())) { if (schemaArg.alias_info() != nullptr && schemaArg.alias_info()->isWrite()) { mutatingInputArgs_.push_back(node->getInput(schemaArg.name()).value); } } - - numOutputs_ = schema_.returns().size(); } void UnsafeAutoFunctionalizeKernel::computeInternal( diff --git a/torch/nativert/kernels/C10Kernel.cpp b/torch/nativert/kernels/C10Kernel.cpp index 450042e7c92d..3c207e5708a3 100644 --- a/torch/nativert/kernels/C10Kernel.cpp +++ b/torch/nativert/kernels/C10Kernel.cpp @@ -49,8 +49,10 @@ void C10Kernel::computeInternal(ExecutionFrame& executionFrame) const { // these are named I don't think it will ever happen in practice. We need to // enforce it though. const auto& outputValues = node_->outputs(); - TORCH_CHECK_EQ(outputValues.size(), stack.size()) - << "Output size mismatch for " << node_->toString(); + TORCH_CHECK( + outputValues.size() == stack.size(), + "Output size mismatch for ", + node_->toString()); for (auto&& [i, actualOutput] : c10::enumerate(stack)) { executionFrame.setIValue(outputValues[i]->id(), std::move(actualOutput)); } diff --git a/torch/nativert/kernels/CallTorchBindKernel.cpp b/torch/nativert/kernels/CallTorchBindKernel.cpp index 5e8c9cf6be75..c3643cbce1da 100644 --- a/torch/nativert/kernels/CallTorchBindKernel.cpp +++ b/torch/nativert/kernels/CallTorchBindKernel.cpp @@ -8,7 +8,7 @@ namespace torch::nativert { CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) { const Value* customObjValue = node_->inputs()[0].value; - CHECK(customObjValue->type() == Type::Kind::CustomObj); + TORCH_CHECK(customObjValue->type() == Type::Kind::CustomObj); customClassName_ = customObjValue->type().classFqn(); customClassType_ = torch::jit::getCustomClass(customClassName_); @@ -16,16 +16,18 @@ CallTorchBindKernel::CallTorchBindKernel(const Node* node) : OpKernel(node) { // sample schema // torch.ops.higher_order.call_torchbind(arg1_1, 'add_tensor', arg0_1); - CHECK(node->attributes().size() == 1) - << "Expects higher_order.call_torchbind to only have a single attribute, methodName"; + TORCH_CHECK( + node->attributes().size() == 1, + "Expects higher_order.call_torchbind to only have a single attribute, methodName"); const auto& attr = node->attributes()[0]; - CHECK(std::holds_alternative(attr.value)) - << "method should be a string"; + TORCH_CHECK( + std::holds_alternative(attr.value), + "method should be a string"); methodName_ = std::get(attr.value); method_ = customClassType_->findMethod(methodName_); - CHECK(method_ != nullptr) << "method not found: " << methodName_; + TORCH_CHECK(method_ != nullptr, "method not found: ", methodName_); } void CallTorchBindKernel::computeInternal( @@ -42,7 +44,7 @@ void CallTorchBindKernel::computeInternal( // set outputs const auto& outputs = node_->outputs(); - TORCH_CHECK_EQ(outputs.size(), stack.size()); + TORCH_CHECK(outputs.size() == stack.size()); for (auto&& [i, outputValue] : c10::enumerate(stack)) { executionFrame.setIValue(outputs[i]->id(), std::move(outputValue)); } diff --git a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp index c33fb81604f6..e8d7170fdf1c 100644 --- a/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedNativeStaticDispatchKernels.cpp @@ -39,7 +39,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::view_as_real(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.view_as_complex.default", @@ -48,31 +48,31 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::view_as_complex(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.real.default", aten_real_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::real(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.imag.default", aten_imag_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::imag(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten._conj.default", aten__conj_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::_conj(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.conj.default", aten_conj_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::conj(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.resolve_conj.default", @@ -81,7 +81,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::resolve_conj(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.resolve_neg.default", @@ -90,7 +90,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::resolve_neg(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._neg_view.default", @@ -99,7 +99,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::_neg_view(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.diagonal.default", @@ -111,7 +111,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim2 = KernelInput(3).toInt(); KernelOutput(0) = at::native::diagonal(self, offset, dim1, dim2); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.linalg_diagonal.default", @@ -123,7 +123,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim2 = KernelInput(3).toInt(); KernelOutput(0) = at::native::linalg_diagonal(A, offset, dim1, dim2); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.expand_as.default", @@ -133,7 +133,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::expand_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.flatten.using_ints", @@ -144,7 +144,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto end_dim = KernelInput(2).toInt(); KernelOutput(0) = at::native::flatten(self, start_dim, end_dim); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.movedim.int", aten_movedim_int, { const auto& self = KernelInput(0).toTensor(); @@ -152,7 +152,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.movedim.int", aten_movedim_int, { const auto destination = KernelInput(2).toInt(); KernelOutput(0) = at::native::movedim(self, source, destination); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.moveaxis.int", aten_moveaxis_int, { const auto& self = KernelInput(0).toTensor(); @@ -160,7 +160,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.moveaxis.int", aten_moveaxis_int, { const auto destination = KernelInput(2).toInt(); KernelOutput(0) = at::native::moveaxis(self, source, destination); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.numpy_T.default", @@ -169,7 +169,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::numpy_T(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.matrix_H.default", @@ -178,19 +178,19 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::matrix_H(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mT.default", aten_mT_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::mT(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.mH.default", aten_mH_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::mH(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.adjoint.default", @@ -199,13 +199,13 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::adjoint(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.ravel.default", aten_ravel_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::ravel(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.reshape_as.default", @@ -215,7 +215,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::reshape_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.detach.default", @@ -224,7 +224,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::detach(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.squeeze.default", @@ -233,20 +233,20 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::squeeze(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.squeeze.dim", aten_squeeze_dim, { const auto& self = KernelInput(0).toTensor(); const auto dim = KernelInput(1).toInt(); KernelOutput(0) = at::native::squeeze(self, dim); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.t.default", aten_t_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::t(self); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.transpose.int", aten_transpose_int, { const auto& self = KernelInput(0).toTensor(); @@ -254,7 +254,7 @@ REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.transpose.int", aten_transpose_int, { const auto dim1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::transpose(self, dim0, dim1); return; -}); +}) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.unsqueeze.default", @@ -264,7 +264,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim = KernelInput(1).toInt(); KernelOutput(0) = at::native::unsqueeze(self, dim); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.view_as.default", @@ -274,7 +274,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& other = KernelInput(1).toTensor(); KernelOutput(0) = at::native::view_as(self, other); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.positive.default", @@ -283,7 +283,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::positive(self); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._autocast_to_reduced_precision.default", @@ -297,7 +297,7 @@ REGISTER_NATIVE_CPU_KERNEL( KernelOutput(0) = at::native::_autocast_to_reduced_precision( self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten._autocast_to_full_precision.default", @@ -309,7 +309,7 @@ REGISTER_NATIVE_CPU_KERNEL( KernelOutput(0) = at::native::_autocast_to_full_precision( self, cuda_enabled, cpu_enabled); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.swapaxes.default", @@ -320,7 +320,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto axis1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::swapaxes(self, axis0, axis1); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.swapdims.default", @@ -331,7 +331,7 @@ REGISTER_NATIVE_CPU_KERNEL( const auto dim1 = KernelInput(2).toInt(); KernelOutput(0) = at::native::swapdims(self, dim0, dim1); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL( "torch.ops.aten.unfold.default", @@ -343,12 +343,12 @@ REGISTER_NATIVE_CPU_KERNEL( const auto step = KernelInput(3).toInt(); KernelOutput(0) = at::native::unfold(self, dimension, size, step); return; - }); + }) REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.alias.default", aten_alias_default, { const auto& self = KernelInput(0).toTensor(); KernelOutput(0) = at::native::alias(self); return; -}); +}) } // namespace torch::nativert diff --git a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp index 986eb060cb0f..f919639f48de 100644 --- a/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp +++ b/torch/nativert/kernels/GeneratedStaticDispatchKernels.cpp @@ -41,7 +41,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.absolute.default", aten_absolute_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::absolute_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.angle.default", aten_angle_default, { const auto& self = KernelInput(0).toTensor(); @@ -52,7 +52,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.angle.default", aten_angle_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::angle_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sgn.default", aten_sgn_default, { const auto& self = KernelInput(0).toTensor(); @@ -63,7 +63,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sgn.default", aten_sgn_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sgn_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.acos.default", aten_acos_default, { const auto& self = KernelInput(0).toTensor(); @@ -74,7 +74,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.acos.default", aten_acos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::acos_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arccos.default", aten_arccos_default, { const auto& self = KernelInput(0).toTensor(); @@ -85,7 +85,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arccos.default", aten_arccos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arccos_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.add.Tensor", aten_add_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -98,7 +98,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.add.Tensor", aten_add_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::add_out(out, self, other, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.add.Scalar", aten_add_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -110,7 +110,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.add.Scalar", aten_add_Scalar, { auto& out_t = KernelOutput(0).toTensor(); fastResizeToZero(out_t); at::add_out(out_t, self, other, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten._add_relu.Tensor", aten__add_relu_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -123,7 +123,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten._add_relu.Tensor", aten__add_relu_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::add_relu_out(self, other, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addmv.default", aten_addmv_default, { const auto& self = KernelInput(0).toTensor(); @@ -138,7 +138,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addmv.default", aten_addmv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addmv_out(out, self, mat, vec, beta, alpha); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addr.default", aten_addr_default, { const auto& self = KernelInput(0).toTensor(); @@ -153,7 +153,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addr.default", aten_addr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::addr_out(self, vec1, vec2, beta, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.all.dim", aten_all_dim, { const auto& self = KernelInput(0).toTensor(); @@ -166,7 +166,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.all.dim", aten_all_dim, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::all_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.any.dim", aten_any_dim, { const auto& self = KernelInput(0).toTensor(); @@ -179,7 +179,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.any.dim", aten_any_dim, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::any_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.argmax.default", aten_argmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -192,7 +192,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.argmax.default", aten_argmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::argmax_out(out, self, dim, keepdim); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.acosh.default", aten_acosh_default, { const auto& self = KernelInput(0).toTensor(); @@ -203,7 +203,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.acosh.default", aten_acosh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::acosh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.asinh.default", aten_asinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -214,7 +214,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.asinh.default", aten_asinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::asinh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arcsinh.default", aten_arcsinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -225,7 +225,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arcsinh.default", aten_arcsinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arcsinh_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atanh.default", aten_atanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -236,7 +236,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atanh.default", aten_atanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atanh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctanh.default", aten_arctanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -247,7 +247,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctanh.default", aten_arctanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctanh_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.asin.default", aten_asin_default, { const auto& self = KernelInput(0).toTensor(); @@ -258,7 +258,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.asin.default", aten_asin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::asin_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arcsin.default", aten_arcsin_default, { const auto& self = KernelInput(0).toTensor(); @@ -269,7 +269,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arcsin.default", aten_arcsin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arcsin_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atan.default", aten_atan_default, { const auto& self = KernelInput(0).toTensor(); @@ -280,7 +280,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atan.default", aten_atan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atan_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctan.default", aten_arctan_default, { const auto& self = KernelInput(0).toTensor(); @@ -291,7 +291,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctan.default", aten_arctan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctan_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.baddbmm.default", aten_baddbmm_default, { const auto& self = KernelInput(0).toTensor(); @@ -306,7 +306,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.baddbmm.default", aten_baddbmm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::baddbmm_out(out, self, batch1, batch2, beta, alpha); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_not.default", @@ -320,7 +320,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_not_out(out, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.copysign.Tensor", aten_copysign_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -332,7 +332,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.copysign.Tensor", aten_copysign_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::copysign_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_not.default", @@ -346,7 +346,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_not_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_xor.default", @@ -361,7 +361,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_xor_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_and.default", @@ -376,7 +376,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_and_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logical_or.default", @@ -391,7 +391,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logical_or_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.ceil.default", aten_ceil_default, { const auto& self = KernelInput(0).toTensor(); @@ -402,7 +402,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ceil.default", aten_ceil_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ceil_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clamp.default", aten_clamp_default, { const auto& self = KernelInput(0).toTensor(); @@ -415,7 +415,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp.default", aten_clamp_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_out(out, self, min, max); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clamp.Tensor", aten_clamp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -428,7 +428,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp.Tensor", aten_clamp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_out(out, self, min, max); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.clamp_max.default", @@ -443,7 +443,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_max_out(out, self, max); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.clamp_max.Tensor", aten_clamp_max_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -455,7 +455,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clamp_max.Tensor", aten_clamp_max_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::clamp_max_out(out, self, max); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.clip.default", aten_clip_default, { const auto& self = KernelInput(0).toTensor(); @@ -468,7 +468,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.clip.default", aten_clip_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::clip_out(self, min, max, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.complex.default", aten_complex_default, { const auto& real = KernelInput(0).toTensor(); @@ -480,7 +480,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.complex.default", aten_complex_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::complex_out(real, imag, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.polar.default", aten_polar_default, { const auto& abs = KernelInput(0).toTensor(); @@ -492,7 +492,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.polar.default", aten_polar_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::polar_out(abs, angle, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cos.default", aten_cos_default, { const auto& self = KernelInput(0).toTensor(); @@ -503,7 +503,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cos.default", aten_cos_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cos_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cosh.default", aten_cosh_default, { const auto& self = KernelInput(0).toTensor(); @@ -514,7 +514,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cosh.default", aten_cosh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cosh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cumprod.default", aten_cumprod_default, { const auto& self = KernelInput(0).toTensor(); @@ -527,7 +527,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cumprod.default", aten_cumprod_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::cumprod_out(out, self, dim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.diff.default", aten_diff_default, { const auto& self = KernelInput(0).toTensor(); @@ -542,7 +542,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.diff.default", aten_diff_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::diff_out(self, n, dim, prepend, append, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor", aten_div_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -554,7 +554,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor", aten_div_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::div_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor_mode", aten_div_Tensor_mode, { const auto& self = KernelInput(0).toTensor(); @@ -567,7 +567,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.div.Tensor_mode", aten_div_Tensor_mode, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::div_out(out, self, other, rounding_mode); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.divide.Tensor", aten_divide_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -579,7 +579,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.divide.Tensor", aten_divide_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::divide_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.true_divide.Tensor", @@ -594,7 +594,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::true_divide_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.dot.default", aten_dot_default, { const auto& self = KernelInput(0).toTensor(); @@ -606,7 +606,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.dot.default", aten_dot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::dot_out(self, tensor, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.vdot.default", aten_vdot_default, { const auto& self = KernelInput(0).toTensor(); @@ -618,7 +618,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.vdot.default", aten_vdot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::vdot_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.erf.default", aten_erf_default, { const auto& self = KernelInput(0).toTensor(); @@ -629,7 +629,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erf.default", aten_erf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erf_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.erfc.default", aten_erfc_default, { const auto& self = KernelInput(0).toTensor(); @@ -640,7 +640,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erfc.default", aten_erfc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erfc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.exp.default", aten_exp_default, { const auto& self = KernelInput(0).toTensor(); @@ -651,7 +651,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.exp.default", aten_exp_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::exp_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.exp2.default", aten_exp2_default, { const auto& self = KernelInput(0).toTensor(); @@ -662,7 +662,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.exp2.default", aten_exp2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::exp2_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.expm1.default", aten_expm1_default, { const auto& self = KernelInput(0).toTensor(); @@ -673,7 +673,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.expm1.default", aten_expm1_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::expm1_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.floor.default", aten_floor_default, { const auto& self = KernelInput(0).toTensor(); @@ -684,7 +684,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.floor.default", aten_floor_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::floor_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.frac.default", aten_frac_default, { const auto& self = KernelInput(0).toTensor(); @@ -695,7 +695,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.frac.default", aten_frac_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::frac_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gcd.default", aten_gcd_default, { const auto& self = KernelInput(0).toTensor(); @@ -707,7 +707,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gcd.default", aten_gcd_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gcd_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lcm.default", aten_lcm_default, { const auto& self = KernelInput(0).toTensor(); @@ -719,7 +719,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lcm.default", aten_lcm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lcm_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.index_copy.default", @@ -736,7 +736,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::index_copy_out(out, self, dim, index, source); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Tensor_Tensor", @@ -754,7 +754,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, elements, test_elements, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Tensor_Scalar", @@ -772,7 +772,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, elements, test_element, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.isin.Scalar_Tensor", @@ -790,7 +790,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isin_out(out, element, test_elements, assume_unique, invert); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.kron.default", aten_kron_default, { const auto& self = KernelInput(0).toTensor(); @@ -802,7 +802,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.kron.default", aten_kron_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::kron_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ldexp.Tensor", aten_ldexp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -814,7 +814,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ldexp.Tensor", aten_ldexp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::ldexp_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log10.default", aten_log10_default, { const auto& self = KernelInput(0).toTensor(); @@ -825,7 +825,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log10.default", aten_log10_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log10_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log1p.default", aten_log1p_default, { const auto& self = KernelInput(0).toTensor(); @@ -836,7 +836,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log1p.default", aten_log1p_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log1p_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.log2.default", aten_log2_default, { const auto& self = KernelInput(0).toTensor(); @@ -847,7 +847,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.log2.default", aten_log2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::log2_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.logaddexp.default", @@ -862,7 +862,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::logaddexp_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logaddexp2.default", @@ -877,7 +877,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::logaddexp2_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.xlogy.Tensor", aten_xlogy_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -889,7 +889,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.xlogy.Tensor", aten_xlogy_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::xlogy_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten._log_softmax.default", @@ -905,7 +905,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_log_softmax_out(out, self, dim, half_to_float); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._logcumsumexp.default", @@ -920,7 +920,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::_logcumsumexp_out_cpu(self, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.logcumsumexp.default", @@ -935,7 +935,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::logcumsumexp_out(self, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.matrix_power.default", @@ -950,7 +950,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::matrix_power_out(self, n, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mm.default", aten_mm_default, { const auto& self = KernelInput(0).toTensor(); @@ -962,7 +962,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mm.default", aten_mm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mm_out(out, self, mat2); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.multiply.Tensor", aten_multiply_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -974,7 +974,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.multiply.Tensor", aten_multiply_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::multiply_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mv.default", aten_mv_default, { const auto& self = KernelInput(0).toTensor(); @@ -986,7 +986,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mv.default", aten_mv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::mv_out(self, vec, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.mvlgamma.default", aten_mvlgamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -998,7 +998,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mvlgamma.default", aten_mvlgamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::mvlgamma_out(self, p, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.rad2deg.default", aten_rad2deg_default, { const auto& self = KernelInput(0).toTensor(); @@ -1009,7 +1009,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.rad2deg.default", aten_rad2deg_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::rad2deg_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.deg2rad.default", aten_deg2rad_default, { const auto& self = KernelInput(0).toTensor(); @@ -1020,7 +1020,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.deg2rad.default", aten_deg2rad_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::deg2rad_out(self, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.reciprocal.default", @@ -1034,7 +1034,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::reciprocal_out(out, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.neg.default", aten_neg_default, { const auto& self = KernelInput(0).toTensor(); @@ -1045,7 +1045,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.neg.default", aten_neg_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::neg_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.negative.default", aten_negative_default, { const auto& self = KernelInput(0).toTensor(); @@ -1056,7 +1056,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.negative.default", aten_negative_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::negative_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.round.default", aten_round_default, { const auto& self = KernelInput(0).toTensor(); @@ -1067,7 +1067,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.round.default", aten_round_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::round_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.round.decimals", aten_round_decimals, { const auto& self = KernelInput(0).toTensor(); @@ -1079,7 +1079,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.round.decimals", aten_round_decimals, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::round_out(out, self, decimals); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gelu.default", aten_gelu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1091,7 +1091,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gelu.default", aten_gelu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gelu_out(out, self, approximate); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardshrink.default", @@ -1106,7 +1106,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hardshrink_out(out, self, lambd); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.hardshrink_backward.default", @@ -1122,7 +1122,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::hardshrink_backward_out(grad_input, grad_out, self, lambd); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.rsqrt.default", aten_rsqrt_default, { const auto& self = KernelInput(0).toTensor(); @@ -1133,7 +1133,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.rsqrt.default", aten_rsqrt_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::rsqrt_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.silu.default", aten_silu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1144,7 +1144,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.silu.default", aten_silu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::silu_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.silu_backward.default", @@ -1159,7 +1159,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::silu_backward_out(grad_input, grad_output, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mish.default", aten_mish_default, { const auto& self = KernelInput(0).toTensor(); @@ -1170,7 +1170,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mish.default", aten_mish_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mish_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sigmoid.default", aten_sigmoid_default, { const auto& self = KernelInput(0).toTensor(); @@ -1181,7 +1181,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sigmoid.default", aten_sigmoid_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sigmoid_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sin.default", aten_sin_default, { const auto& self = KernelInput(0).toTensor(); @@ -1192,7 +1192,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sin.default", aten_sin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sin_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sinc.default", aten_sinc_default, { const auto& self = KernelInput(0).toTensor(); @@ -1203,7 +1203,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sinc.default", aten_sinc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sinc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sinh.default", aten_sinh_default, { const auto& self = KernelInput(0).toTensor(); @@ -1214,7 +1214,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sinh.default", aten_sinh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sinh_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten._softmax.default", aten__softmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -1227,7 +1227,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten._softmax.default", aten__softmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_softmax_out(out, self, dim, half_to_float); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.sqrt.default", aten_sqrt_default, { const auto& self = KernelInput(0).toTensor(); @@ -1238,7 +1238,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.sqrt.default", aten_sqrt_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::sqrt_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.square.default", aten_square_default, { const auto& self = KernelInput(0).toTensor(); @@ -1249,7 +1249,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.square.default", aten_square_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::square_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.prod.default", aten_prod_default, { const auto& self = KernelInput(0).toTensor(); @@ -1261,7 +1261,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.prod.default", aten_prod_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::prod_out(self, dtype, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.prod.dim_int", aten_prod_dim_int, { const auto& self = KernelInput(0).toTensor(); @@ -1275,7 +1275,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.prod.dim_int", aten_prod_dim_int, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::prod_out(out, self, dim, keepdim, dtype); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.tan.default", aten_tan_default, { const auto& self = KernelInput(0).toTensor(); @@ -1286,7 +1286,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tan.default", aten_tan_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tan_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.tanh.default", aten_tanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -1297,7 +1297,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tanh.default", aten_tanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tanh_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.threshold.default", @@ -1313,7 +1313,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::threshold_out(out, self, threshold, value); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.threshold_backward.default", @@ -1330,7 +1330,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::threshold_backward_out(grad_input, grad_output, self, threshold); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.trunc.default", aten_trunc_default, { const auto& self = KernelInput(0).toTensor(); @@ -1341,7 +1341,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.trunc.default", aten_trunc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::trunc_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fix.default", aten_fix_default, { const auto& self = KernelInput(0).toTensor(); @@ -1352,7 +1352,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fix.default", aten_fix_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::fix_out(self, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nuclear_norm.default", @@ -1367,7 +1367,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::nuclear_norm_out(self, keepdim, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.subtract.Tensor", aten_subtract_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1380,7 +1380,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.subtract.Tensor", aten_subtract_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::subtract_out(self, other, alpha, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.heaviside.default", @@ -1395,7 +1395,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::heaviside_out(out, self, values); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._addmm_activation.default", @@ -1416,7 +1416,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::_addmm_activation_out( out, self, mat1, mat2, beta, alpha, use_gelu); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.index_add.default", @@ -1434,7 +1434,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::index_add_out(out, self, dim, index, source, alpha); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.src", aten_scatter_src, { const auto& self = KernelInput(0).toTensor(); @@ -1448,7 +1448,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.src", aten_scatter_src, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, src); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.value", aten_scatter_value, { const auto& self = KernelInput(0).toTensor(); @@ -1462,7 +1462,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.value", aten_scatter_value, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, value); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.scatter.reduce", aten_scatter_reduce, { const auto& self = KernelInput(0).toTensor(); @@ -1477,7 +1477,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.scatter.reduce", aten_scatter_reduce, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, src, reduce); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter.value_reduce", @@ -1495,7 +1495,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_out(out, self, dim, index, value, reduce); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter_add.default", @@ -1512,7 +1512,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::scatter_add_out(out, self, dim, index, src); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.scatter_reduce.two", @@ -1533,7 +1533,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::scatter_reduce_out( out, self, dim, index, src, reduce, include_self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.eq.Scalar", aten_eq_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1545,7 +1545,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.eq.Scalar", aten_eq_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::eq_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.eq.Tensor", aten_eq_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1557,7 +1557,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.eq.Tensor", aten_eq_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::eq_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_and.Tensor", @@ -1572,7 +1572,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_and_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_or.Tensor", @@ -1587,7 +1587,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_or_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_xor.Tensor", @@ -1602,7 +1602,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_xor_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_left_shift.Tensor", @@ -1617,7 +1617,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_left_shift_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.bitwise_right_shift.Tensor", @@ -1632,7 +1632,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::bitwise_right_shift_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.tril.default", aten_tril_default, { const auto& self = KernelInput(0).toTensor(); @@ -1644,7 +1644,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.tril.default", aten_tril_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::tril_out(out, self, diagonal); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.triu.default", aten_triu_default, { const auto& self = KernelInput(0).toTensor(); @@ -1656,7 +1656,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.triu.default", aten_triu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::triu_out(out, self, diagonal); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.digamma.default", aten_digamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -1667,7 +1667,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.digamma.default", aten_digamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::digamma_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Scalar", aten_lerp_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1680,7 +1680,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Scalar", aten_lerp_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lerp_out(out, self, end, weight); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Tensor", aten_lerp_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1693,7 +1693,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lerp.Tensor", aten_lerp_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lerp_out(out, self, end, weight); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addbmm.default", aten_addbmm_default, { const auto& self = KernelInput(0).toTensor(); @@ -1708,7 +1708,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addbmm.default", aten_addbmm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::addbmm_out(self, batch1, batch2, beta, alpha, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.cross.default", aten_cross_default, { const auto& self = KernelInput(0).toTensor(); @@ -1721,7 +1721,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.cross.default", aten_cross_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cross_out(self, other, dim, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ne.Scalar", aten_ne_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1733,7 +1733,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ne.Scalar", aten_ne_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ne_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ne.Tensor", aten_ne_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1745,7 +1745,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ne.Tensor", aten_ne_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ne_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ge.Scalar", aten_ge_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1757,7 +1757,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ge.Scalar", aten_ge_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ge_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ge.Tensor", aten_ge_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1769,7 +1769,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ge.Tensor", aten_ge_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::ge_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.le.Scalar", aten_le_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1781,7 +1781,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.le.Scalar", aten_le_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::le_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.le.Tensor", aten_le_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1793,7 +1793,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.le.Tensor", aten_le_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::le_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gt.Scalar", aten_gt_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1805,7 +1805,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gt.Scalar", aten_gt_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.gt.Tensor", aten_gt_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1817,7 +1817,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gt.Tensor", aten_gt_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lt.Scalar", aten_lt_Scalar, { const auto& self = KernelInput(0).toTensor(); @@ -1829,7 +1829,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lt.Scalar", aten_lt_Scalar, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lt.Tensor", aten_lt_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -1841,7 +1841,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lt.Tensor", aten_lt_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lt_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.take.default", aten_take_default, { const auto& self = KernelInput(0).toTensor(); @@ -1853,7 +1853,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.take.default", aten_take_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::take_out(self, index, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.take_along_dim.default", @@ -1869,7 +1869,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::take_along_dim_out(self, indices, dim, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.masked_select.default", @@ -1884,7 +1884,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::masked_select_out_cpu(self, mask, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.gather.default", aten_gather_default, { const auto& self = KernelInput(0).toTensor(); @@ -1898,7 +1898,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.gather.default", aten_gather_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::gather_out(out, self, dim, index, sparse_grad); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addcmul.default", aten_addcmul_default, { const auto& self = KernelInput(0).toTensor(); @@ -1912,7 +1912,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addcmul.default", aten_addcmul_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addcmul_out(out, self, tensor1, tensor2, value); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.addcdiv.default", aten_addcdiv_default, { const auto& self = KernelInput(0).toTensor(); @@ -1926,7 +1926,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.addcdiv.default", aten_addcdiv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::addcdiv_out(out, self, tensor1, tensor2, value); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_solve_triangular.default", @@ -1946,7 +1946,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::native::linalg_solve_triangular_out( self, B, upper, left, unitriangular, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.cholesky_solve.default", @@ -1962,7 +1962,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cholesky_solve_out(self, input2, upper, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.cholesky_inverse.default", @@ -1977,7 +1977,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::cholesky_inverse_out(self, upper, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.orgqr.default", aten_orgqr_default, { const auto& self = KernelInput(0).toTensor(); @@ -1989,7 +1989,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.orgqr.default", aten_orgqr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::orgqr_out(self, input2, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.ormqr.default", aten_ormqr_default, { const auto& self = KernelInput(0).toTensor(); @@ -2004,7 +2004,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.ormqr.default", aten_ormqr_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::ormqr_out(self, input2, input3, left, transpose, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.lgamma.default", aten_lgamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -2015,7 +2015,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.lgamma.default", aten_lgamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::lgamma_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.polygamma.default", @@ -2030,7 +2030,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::polygamma_out(out, n, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.erfinv.default", aten_erfinv_default, { const auto& self = KernelInput(0).toTensor(); @@ -2041,7 +2041,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.erfinv.default", aten_erfinv_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::erfinv_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.i0.default", aten_i0_default, { const auto& self = KernelInput(0).toTensor(); @@ -2052,7 +2052,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.i0.default", aten_i0_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::i0_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.signbit.default", aten_signbit_default, { const auto& self = KernelInput(0).toTensor(); @@ -2063,7 +2063,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.signbit.default", aten_signbit_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::signbit_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.atan2.default", aten_atan2_default, { const auto& self = KernelInput(0).toTensor(); @@ -2075,7 +2075,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.atan2.default", aten_atan2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::atan2_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.arctan2.default", aten_arctan2_default, { const auto& self = KernelInput(0).toTensor(); @@ -2087,7 +2087,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.arctan2.default", aten_arctan2_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::arctan2_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.histc.default", aten_histc_default, { const auto& self = KernelInput(0).toTensor(); @@ -2101,7 +2101,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.histc.default", aten_histc_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::histogram_histc_out(self, bins, min, max, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fmod.Tensor", aten_fmod_Tensor, { const auto& self = KernelInput(0).toTensor(); @@ -2113,7 +2113,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmod.Tensor", aten_fmod_Tensor, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmod_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.hypot.default", aten_hypot_default, { const auto& self = KernelInput(0).toTensor(); @@ -2125,7 +2125,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.hypot.default", aten_hypot_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hypot_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.igamma.default", aten_igamma_default, { const auto& self = KernelInput(0).toTensor(); @@ -2137,7 +2137,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.igamma.default", aten_igamma_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::igamma_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.igammac.default", aten_igammac_default, { const auto& self = KernelInput(0).toTensor(); @@ -2149,7 +2149,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.igammac.default", aten_igammac_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::igammac_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nextafter.default", @@ -2164,7 +2164,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::nextafter_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.fmin.default", aten_fmin_default, { const auto& self = KernelInput(0).toTensor(); @@ -2176,7 +2176,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmin.default", aten_fmin_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmin_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.fmax.default", aten_fmax_default, { const auto& self = KernelInput(0).toTensor(); @@ -2188,7 +2188,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.fmax.default", aten_fmax_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::fmax_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.maximum.default", aten_maximum_default, { const auto& self = KernelInput(0).toTensor(); @@ -2200,7 +2200,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.maximum.default", aten_maximum_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::maximum_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.minimum.default", aten_minimum_default, { const auto& self = KernelInput(0).toTensor(); @@ -2212,7 +2212,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.minimum.default", aten_minimum_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::minimum_out(out, self, other); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.min.other", aten_min_other, { const auto& self = KernelInput(0).toTensor(); @@ -2224,7 +2224,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.min.other", aten_min_other, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::min_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.quantile.default", aten_quantile_default, { const auto& self = KernelInput(0).toTensor(); @@ -2240,7 +2240,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.quantile.default", aten_quantile_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::quantile_out(self, q, dim, keepdim, interpolation, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.nanquantile.default", @@ -2259,7 +2259,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::nanquantile_out(self, q, dim, keepdim, interpolation, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.msort.default", aten_msort_default, { const auto& self = KernelInput(0).toTensor(); @@ -2270,7 +2270,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.msort.default", aten_msort_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::msort_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.all.default", aten_all_default, { const auto& self = KernelInput(0).toTensor(); @@ -2281,7 +2281,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.all.default", aten_all_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::all_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.any.default", aten_any_default, { const auto& self = KernelInput(0).toTensor(); @@ -2292,7 +2292,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.any.default", aten_any_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::any_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.renorm.default", aten_renorm_default, { const auto& self = KernelInput(0).toTensor(); @@ -2306,7 +2306,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.renorm.default", aten_renorm_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::renorm_out(out, self, p, dim, maxnorm); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten._convert_indices_from_coo_to_csr.default", @@ -2323,7 +2323,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::_convert_indices_from_coo_to_csr_out(out, self, size, out_int32); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten._convert_indices_from_csr_to_coo.default", @@ -2342,7 +2342,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::cpu::_convert_indices_from_csr_to_coo_out( out, crow_indices, col_indices, out_int32, transpose); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.mse_loss.default", aten_mse_loss_default, { const auto& self = KernelInput(0).toTensor(); @@ -2355,7 +2355,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.mse_loss.default", aten_mse_loss_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::mse_loss_out(out, self, target, reduction); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.multi_margin_loss.default", @@ -2376,7 +2376,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(out); at::native::multi_margin_loss_cpu_out( self, target, p, margin, weight, reduction, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.multilabel_margin_loss.default", @@ -2393,7 +2393,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::multilabel_margin_loss_out(self, target, reduction, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.soft_margin_loss.default", @@ -2409,7 +2409,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::soft_margin_loss_out(self, target, reduction, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.elu.default", aten_elu_default, { const auto& self = KernelInput(0).toTensor(); @@ -2423,7 +2423,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.elu.default", aten_elu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::elu_out(out, self, alpha, scale, input_scale); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.elu_backward.default", @@ -2450,7 +2450,7 @@ REGISTER_CPU_KERNEL( input_scale, is_result, self_or_result); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.glu.default", aten_glu_default, { const auto& self = KernelInput(0).toTensor(); @@ -2462,7 +2462,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.glu.default", aten_glu_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::glu_out(out, self, dim); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardsigmoid.default", @@ -2476,7 +2476,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::hardsigmoid_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.hardsigmoid_backward.default", @@ -2491,7 +2491,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::hardsigmoid_backward_out(grad_input, grad_output, self); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.hardtanh.default", aten_hardtanh_default, { const auto& self = KernelInput(0).toTensor(); @@ -2504,7 +2504,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.hardtanh.default", aten_hardtanh_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::hardtanh_out(self, min_val, max_val, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.hardswish.default", @@ -2518,7 +2518,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::hardswish_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.leaky_relu_backward.default", @@ -2537,7 +2537,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(grad_input); at::cpu::leaky_relu_backward_out( grad_input, grad_output, self, negative_slope, self_is_result); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.log_sigmoid.default", @@ -2551,7 +2551,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::log_sigmoid_out(self, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.softplus.default", aten_softplus_default, { const auto& self = KernelInput(0).toTensor(); @@ -2564,7 +2564,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.softplus.default", aten_softplus_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::softplus_out(out, self, beta, threshold); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.softplus_backward.default", @@ -2583,7 +2583,7 @@ REGISTER_CPU_KERNEL( fastResizeToZero(grad_input); at::cpu::softplus_backward_out( grad_input, grad_output, self, beta, threshold); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.softshrink.default", @@ -2598,7 +2598,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::softshrink_out(out, self, lambd); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.softshrink_backward.default", @@ -2615,7 +2615,7 @@ REGISTER_CPU_KERNEL( auto& grad_input = KernelOutput(0).toTensor(); fastResizeToZero(grad_input); at::cpu::softshrink_backward_out(grad_input, grad_output, self, lambd); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.isposinf.default", aten_isposinf_default, { const auto& self = KernelInput(0).toTensor(); @@ -2626,7 +2626,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.isposinf.default", aten_isposinf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isposinf_out(out, self); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.isneginf.default", aten_isneginf_default, { const auto& self = KernelInput(0).toTensor(); @@ -2637,7 +2637,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.isneginf.default", aten_isneginf_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::isneginf_out(out, self); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.special_entr.default", @@ -2651,7 +2651,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_entr_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_ndtri.default", @@ -2665,7 +2665,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_ndtri_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_log_ndtr.default", @@ -2679,7 +2679,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_log_ndtr_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_expm1.default", @@ -2693,7 +2693,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_expm1_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_exp2.default", @@ -2707,7 +2707,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_exp2_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_psi.default", @@ -2721,7 +2721,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_psi_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_digamma.default", @@ -2735,7 +2735,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_digamma_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammaln.default", @@ -2749,7 +2749,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammaln_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erf.default", @@ -2763,7 +2763,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erf_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfc.default", @@ -2777,7 +2777,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erfc_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfcx.default", @@ -2791,7 +2791,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_erfcx_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_erfinv.default", @@ -2805,7 +2805,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_erfinv_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_ndtr.default", @@ -2819,7 +2819,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_ndtr_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_xlog1py.default", @@ -2834,7 +2834,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_xlog1py_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_xlogy.default", @@ -2849,7 +2849,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_xlogy_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_zeta.default", @@ -2864,7 +2864,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_zeta_out(out, self, other); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i0.default", @@ -2878,7 +2878,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_i0_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i0e.default", @@ -2892,7 +2892,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i0e_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i1.default", @@ -2906,7 +2906,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i1_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_i1e.default", @@ -2920,7 +2920,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::special_i1e_out(out, self); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_polygamma.default", @@ -2935,7 +2935,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_polygamma_out(n, self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_expit.default", @@ -2949,7 +2949,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_expit_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_sinc.default", @@ -2963,7 +2963,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_sinc_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_round.default", @@ -2978,7 +2978,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_round_out(self, decimals, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_log1p.default", @@ -2992,7 +2992,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_log1p_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammainc.default", @@ -3007,7 +3007,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammainc_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_gammaincc.default", @@ -3022,7 +3022,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_gammaincc_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.special_multigammaln.default", @@ -3037,7 +3037,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::special_multigammaln_out(self, p, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_cross.default", @@ -3053,7 +3053,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::cpu::linalg_cross_out(out, self, other, dim); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_det.default", @@ -3067,7 +3067,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_det_out(A, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_matmul.default", @@ -3082,7 +3082,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_matmul_out(self, other, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_eigvals.default", @@ -3096,7 +3096,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_eigvals_out(self, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_inv.default", @@ -3110,7 +3110,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_inv_out(A, out); - }); + }) REGISTER_CPU_KERNEL("torch.ops.aten.inverse.default", aten_inverse_default, { const auto& self = KernelInput(0).toTensor(); @@ -3121,7 +3121,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.inverse.default", aten_inverse_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::inverse_out(self, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.inner.default", aten_inner_default, { const auto& self = KernelInput(0).toTensor(); @@ -3133,7 +3133,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.inner.default", aten_inner_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::inner_out(self, other, out); -}); +}) REGISTER_CPU_KERNEL("torch.ops.aten.outer.default", aten_outer_default, { const auto& self = KernelInput(0).toTensor(); @@ -3145,7 +3145,7 @@ REGISTER_CPU_KERNEL("torch.ops.aten.outer.default", aten_outer_default, { auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::outer_out(self, vec2, out); -}); +}) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_cond.default", @@ -3160,7 +3160,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_cond_out(self, p, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_solve.default", @@ -3176,7 +3176,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_solve_out(A, B, left, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_tensorinv.default", @@ -3191,7 +3191,7 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_tensorinv_out(self, ind, out); - }); + }) REGISTER_CPU_KERNEL( "torch.ops.aten.linalg_matrix_power.default", @@ -3206,6 +3206,6 @@ REGISTER_CPU_KERNEL( auto& out = KernelOutput(0).toTensor(); fastResizeToZero(out); at::native::linalg_matrix_power_out(self, n, out); - }); + }) } // namespace torch::nativert diff --git a/torch/nativert/kernels/HigherOrderKernel.cpp b/torch/nativert/kernels/HigherOrderKernel.cpp index a1f1393c0188..370339c82f82 100644 --- a/torch/nativert/kernels/HigherOrderKernel.cpp +++ b/torch/nativert/kernels/HigherOrderKernel.cpp @@ -11,28 +11,28 @@ HigherOrderKernel::HigherOrderKernel( std::vector> graphExecutors) : OpKernel(node), graphExecutors_(std::move(graphExecutors)) { static constexpr std::string_view prefix = "torch.ops.higher_order."; - CHECK(c10::starts_with(node->target(), prefix)); + TORCH_CHECK(c10::starts_with(node->target(), prefix)); auto opName = node->target().substr(prefix.size()); if (opName == "cond") { opType_ = OpType::COND; // Checking torch.cond schema is as expected: // torch.cond(Tensor predicate, Graph graph1, Graph graph2, Tensor[] args) // -> Tensor[] - TORCH_CHECK_EQ(node_->attributes().size(), 2); - TORCH_CHECK_EQ(node_->inputs().size(), 2); + TORCH_CHECK(node_->attributes().size() == 2); + TORCH_CHECK(node_->inputs().size() == 2); } else if (opName == "while_loop") { opType_ = OpType::WHILE_LOOP; // Checking torch.while_loop schema is as expected: // torch.while_loop(Graph cond, Graph body, Tensor[] args, Tensor[] // additional) -> Tensor[] - TORCH_CHECK_EQ(node_->attributes().size(), 2); - TORCH_CHECK_EQ(node_->inputs().size(), 2); + TORCH_CHECK(node_->attributes().size() == 2); + TORCH_CHECK(node_->inputs().size() == 2); } else if (opName == "run_const_graph") { opType_ = OpType::RUN_CONST_GRAPH; // Checking torch.run_const_graph schema is as expected: // torch.run_const_graph(Graph graph, Tensor[] args) -> Tensor[] - TORCH_CHECK_GE(node_->attributes().size(), 1); - TORCH_CHECK_EQ(node_->inputs().size(), 1); + TORCH_CHECK(!node_->attributes().empty()); + TORCH_CHECK(node_->inputs().size() == 1); } else { throw std::runtime_error( fmt::format("Unknown higher order op: {}", opName)); diff --git a/torch/nativert/kernels/KernelFactory.cpp b/torch/nativert/kernels/KernelFactory.cpp index 1f72fef810d6..da524c8e46b9 100644 --- a/torch/nativert/kernels/KernelFactory.cpp +++ b/torch/nativert/kernels/KernelFactory.cpp @@ -62,7 +62,7 @@ c10::Device inferTargetDevice( } // namespace -inline constexpr std::string_view kSymIntOps[] = { +inline constexpr std::array kSymIntOps = { "_operator.floordiv", "_operator.mod", "torch.sym_int", @@ -72,7 +72,7 @@ inline constexpr std::string_view kSymIntOps[] = { "torch.sym_min", }; -inline constexpr std::string_view kSymBoolOps[] = { +inline constexpr std::array kSymBoolOps = { "_operator.eq", "_operator.ne", "_operator.le", @@ -83,14 +83,14 @@ inline constexpr std::string_view kSymBoolOps[] = { "torch.sym_not", }; -inline constexpr std::string_view kSymFloatOps[] = { +inline constexpr std::array kSymFloatOps = { "torch._sym_sqrt", "math.trunc", "_operator.neg", "_operator.truediv", }; -inline constexpr std::string_view kScalarBinaryOps[] = { +inline constexpr std::array kScalarBinaryOps = { "_operator.mul", "_operator.add", "_operator.sub", @@ -124,11 +124,11 @@ void KernelFactory::registerHandler( ExecutionKernels KernelFactory::initializeNodeKernels( const Graph& graph, - std::shared_ptr weights, + const std::shared_ptr& weights, const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, - std::shared_ptr pytorchStreamReader, - const MakeProxyExecutorFn& makeProxyExecutorFunc) { + const std::shared_ptr& + pytorchStreamReader) { std::vector> nodeKernels; std::vector> delegateExecutors; std::vector constFoldingExecutions; @@ -214,10 +214,12 @@ ExecutionKernels KernelFactory::initializeNodeKernels( const auto& subgraph = std::get>(attr.value); auto executionKernels = initializeNodeKernels( *subgraph, weights, executorConfig, placement); - CHECK(executionKernels.delegateExecutors.empty()) - << "HigherOrderKernel does not support delegates"; - CHECK(executionKernels.constFoldingExecutions.size() == 0) - << "HigherOrderKernel does not support const folding"; + TORCH_CHECK( + executionKernels.delegateExecutors.empty(), + "HigherOrderKernel does not support delegates"); + TORCH_CHECK( + executionKernels.constFoldingExecutions.empty(), + "HigherOrderKernel does not support const folding"); if (executorConfig.maxParallelOps > 1) { graphExecutors.emplace_back( std::unique_ptr(new ParallelGraphExecutor( diff --git a/torch/nativert/kernels/KernelFactory.h b/torch/nativert/kernels/KernelFactory.h index c01d64c3a017..3f341f1115d3 100644 --- a/torch/nativert/kernels/KernelFactory.h +++ b/torch/nativert/kernels/KernelFactory.h @@ -70,16 +70,15 @@ class KernelFactoryHandler { class KernelFactory { public: - explicit KernelFactory() {} + KernelFactory() = default; ExecutionKernels initializeNodeKernels( const Graph& graph, - std::shared_ptr weights, + const std::shared_ptr& weights, const torch::nativert::ExecutorConfig& executorConfig, const Placement& placement, - std::shared_ptr - pytorchStreamReader = nullptr, - const MakeProxyExecutorFn& makeProxyExecutorFunc = nullptr); + const std::shared_ptr& + pytorchStreamReader = nullptr); static void registerHandler( const std::string& name, diff --git a/torch/nativert/kernels/KernelRegistry.cpp b/torch/nativert/kernels/KernelRegistry.cpp new file mode 100644 index 000000000000..77da29528d45 --- /dev/null +++ b/torch/nativert/kernels/KernelRegistry.cpp @@ -0,0 +1,1380 @@ +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace at::native { + +static void repeat_out( + at::Tensor& result, + const Tensor& self, + IntArrayRef repeats) { + TORCH_CHECK( + repeats.size() >= static_cast(self.dim()), + "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); + + // Add new leading dimensions to the tensor if the + // number of target dimensions is larger than the + // number of source dimensions. + int64_t num_new_dimensions = repeats.size() - self.dim(); + DimVector padded_size(num_new_dimensions, 1); + padded_size.insert( + padded_size.end(), self.sizes().begin(), self.sizes().end()); + DimVector target_size(repeats.size()); + bool zero_tensor = false; + for (const auto idx : c10::irange(repeats.size())) { + if (repeats[idx] == 0) { + zero_tensor = true; + } + target_size[idx] = padded_size[idx] * repeats[idx]; + } + + // return an empty tensor if one of the repeat dimensions is zero + at::native::resize_(result, target_size, std::nullopt); + if (zero_tensor) { + return; + } + + Tensor xtensor = at::compositeexplicitautograd::expand(self, padded_size); + Tensor urtensor = at::native::alias(result); + for (const auto i : c10::irange(xtensor.dim())) { + // can't unfold with step 0, so make sure step is at least 1 + // (it doesn't matter what it is in that case, because the size is 0). + urtensor = urtensor.unfold( + i, xtensor.size(i), std::max(xtensor.size(i), 1)); + } + + at::native::copy_(urtensor, xtensor.expand_as(urtensor)); +} + +static Tensor& c2_argmin_out( + Tensor& output, + const Tensor& input, + const int64_t dim, + const bool keepdim) { + const auto ndim = input.dim(); + int64_t dim_ = maybe_wrap_dim(dim, ndim); + TORCH_CHECK(dim_ >= 0 && dim_ < ndim); + + const auto in_dims = input.sizes(); + + c10::SmallVector out_dims; + out_dims.reserve(ndim); + int prev_size = 1; + int next_size = 1; + for (int i = 0; i < dim_; ++i) { + out_dims.push_back(in_dims[i]); + prev_size *= in_dims[i]; + } + if (keepdim) { + out_dims.push_back(1); + } + for (auto i = dim_ + 1; i < ndim; ++i) { + out_dims.push_back(in_dims[i]); + next_size *= in_dims[i]; + } + at::native::resize_(output, out_dims, std::nullopt); + + const auto n = in_dims[dim_]; + + if (next_size == 1) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto in_ptr = input.const_data_ptr(); + const auto out_ptr = output.mutable_data_ptr(); + // input is a [prev_size, n] tensor. + // output is a [prev_size,] tensor. + // Thus, access is contiguous/coalesced. + for (int i = 0; i < prev_size; ++i) { + auto v = std::min_element( + in_ptr + i * n, + in_ptr + (i + 1) * n, + [](scalar_t a, scalar_t b) { + // if a is nan, then a is *less* than b with LessOrNan + // semantics + if (at::_isnan(a)) { + return true; + } + // if a is not nan and b is nan, then a is not less than b + // with LessOrNan semantics otherwise, act normally. If `b` is + // NaN then a < b will always return false, so this is + // equivalent to the first snippet. + return a < b; + }); + out_ptr[i] = std::distance(in_ptr + i * n, v); + } + }); + } else { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, input.scalar_type(), "argmin_input", [&]() { + const auto less_or_nan = native::detail::LessOrNan{}; + + const auto in_ptr = input.const_data_ptr(); + const auto out_ptr = output.mutable_data_ptr(); + + std::memset(out_ptr, 0, prev_size * next_size * sizeof(int64_t)); + + for (int i = 0; i < prev_size; ++i) { + const scalar_t* cur_in_ptr = in_ptr + i * n * next_size + next_size; + for (int k = 1; k < n; ++k) { + for (int j = 0; j < next_size; ++j) { + int64_t* cur_out_ptr = out_ptr + i * next_size + j; + if (less_or_nan( + *cur_in_ptr, + in_ptr + [i * n * next_size + *cur_out_ptr * next_size + j], + *cur_out_ptr, + k)) { + *cur_out_ptr = k; + } + ++cur_in_ptr; + } + } + } + }); + } + return output; +} + +static Tensor& linear_out( + Tensor& output, + const Tensor& input, + const Tensor& weight, + const std::optional& bias_opt) { + TORCH_CHECK(!input.is_mkldnn()); + + auto bias = bias_opt.has_value() + ? c10::MaybeOwned::borrowed(*bias_opt) + : c10::MaybeOwned::owned(std::in_place); + + if (input.dim() == 2 && bias->defined()) { + // Fused op is marginally faster. + return at::cpu::addmm_out(output, *bias, input, weight.t()); + } + at::native::matmul_out(input, weight.t(), output); + if (bias->defined()) { + at::cpu::add_(output, *bias); + } + return output; +} + +static at::Tensor& mul_out( + at::Tensor& output, + const at::Tensor& self, + const at::Scalar& other) { + const auto& t_output = output.scalar_type(); + TORCH_CHECK(at::native::result_type(self, other) == t_output); + + auto self_sizes = self.sizes(); + at::native::resize_(output, self_sizes, std::nullopt); + + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, t_output, "mul_Scalar_out", [&]() { + using output_t = scalar_t; + output_t* output_ptr = output.mutable_data_ptr(); + + const int64_t num_elements = self.numel(); + const void* self_ptr = self.data_ptr(); + + at::parallel_for(0, num_elements, 1, [&](int64_t start, int64_t end) { + for (int64_t i = start; i < end; ++i) { + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, kBFloat16, other.type(), "mul_Scalar_other", [&]() { + using other_t = scalar_t; + + output_t other_casted = static_cast( + reinterpret_cast(other.data_ptr())[0]); + + AT_DISPATCH_ALL_TYPES_AND2( + kHalf, + kBFloat16, + self.scalar_type(), + "mul_Scalar_self", + [&]() { + using self_t = scalar_t; + + output_ptr[i] = + other_casted * + static_cast( + reinterpret_cast(self_ptr)[i]); + }); + }); + } + }); + }); + + return output; +} + +} // namespace at::native + +namespace torch::nativert { + +C10_DEFINE_REGISTRY( + StaticallyDispatchedCPUKernelRegistry, + OpKernel, + const Node*, + c10::Device) + +namespace { + +// device & pin_memory matter only when CUDA is enabled. +static bool hasTensorWithOptions( + const c10::IValue& ivalue, + std::optional dtype, + std::optional layout) { + if (!ivalue.isTensor()) { + return false; + } + const auto& tensor = ivalue.toTensor(); + if (dtype == tensor.dtype().toScalarType() && + layout == tensor.options().layout_opt()) { + return true; + } + VLOG(1) << "tensor exists, but tensor options were different"; + return false; +} + +static bool hasTensorWithOptions( + const c10::IValue& ivalue, + std::optional dtype, + std::optional layout, + std::optional memory_format) { + return hasTensorWithOptions(ivalue, dtype, layout) && + (memory_format == ivalue.toTensor().options().memory_format_opt()); +} + +c10::MaybeOwned borrow_from_optional_tensor_ivalue( + const c10::IValue& iv) { + if (iv.isNone()) { + return c10::MaybeOwned::owned(std::in_place); + } + return c10::MaybeOwned::borrowed(iv.toTensor()); +} + +} // namespace + +REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Tensor", aten_remainder_Tensor, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::remainder(self, KernelInput(1).toTensor()); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::remainder_out(out, self, KernelInput(1).toTensor()); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.remainder.Scalar", aten_remainder_Scalar, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::remainder(self, KernelInput(1).toScalar()); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::remainder_out(self, KernelInput(1).toScalar(), out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.matmul.default", aten_matmul, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::matmul(in0_t, in1_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::matmul_out(in0_t, in1_t, out_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.bmm.default", aten_bmm, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::bmm_out(out_t, in0_t, in1_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.abs.default", aten_abs, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::abs(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::abs_out(in0_t, out_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.mul.Tensor", aten_mul, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::mul(in0_t, in1_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::mul_out(out_t, in0_t, in1_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.mul.Scalar", aten_mul_Scalar, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toScalar(); + auto dtype = at::native::result_type(in0_t, in1_t); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t, dtype); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + KernelOutput(0) = at::native::mul_out(out_t, in0_t, in1_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.nan_to_num.default", aten_nan_to_num, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_d = KernelInput(1).toOptional(); + const auto in2_d = KernelInput(2).toOptional(); + const auto in3_d = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::nan_to_num(in0_t, in1_d, in2_d, in3_d); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.leaky_relu.default", aten_leaky_relu, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_s = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::leaky_relu(in0_t, in1_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + at::cpu::leaky_relu_out(out_t, in0_t, in1_s); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.relu.default", aten_relu, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::threshold_out(out_t, in0_t, 0, 0); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.clone.default", aten_clone, { + const auto& src = KernelInput(0).toTensor(); + const auto& optional_memory_format = + KernelInput(1).toOptional(); + auto memory_format = + optional_memory_format.value_or(c10::MemoryFormat::Preserve); + /* + disable out_variant of clone for case with stride = 0 and + memory formats other than preserve. Perform dynamic allocation + instead of memory reuse for simpler implementation. We could, + in principle, figure out copy of strides. + */ + if ((at::has_internal_overlap(src.unsafeGetTensorImpl()) == + at::MemOverlap::Yes) || + (memory_format != c10::MemoryFormat::Preserve)) { + KernelOutput(0) = at::native::clone(src, memory_format); + return; + } + if (KernelOutput(0).isNone()) { + if (src.is_non_overlapping_and_dense()) { + // Copy all strides + KernelOutput(0) = + at::empty_strided(src.sizes(), src.strides(), src.options()); + } else { + memory_format = src.suggest_memory_format(); + KernelOutput(0) = create_empty_from(src, memory_format); + } + } + auto& out_t = KernelOutput(0).toTensor(); + at::native::resize_impl_cpu_( + out_t.unsafeGetTensorImpl(), src.sizes(), src.strides()); + at::native::copy_(out_t, src, false); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.index.Tensor", aten_index, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_l = + at::native::toListOfOptionalTensors(KernelInput(1).toListRef()); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::index(in0_t, in1_l); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::index_out(out_t, in0_t, in1_l); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.index_select.default", aten_index_select, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto& index = KernelInput(2).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::index_select_cpu_(self, dim, index); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::index_select_out_cpu_(self, dim, index, out); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.pow.Tensor_Tensor", + aten_pow_Tensor_Tensor, + { + if (KernelOutput(0).isNone()) { + const auto& in0_t = KernelInput(0).toTensor(); + auto dtype = at::native::result_type(in0_t, KernelInput(1).toTensor()); + KernelOutput(0) = create_empty_from(in0_t, dtype); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out( + out_t, KernelInput(0).toTensor(), KernelInput(1).toTensor()); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.pow.Scalar", aten_pow_Scalar, { + if (KernelOutput(0).isNone()) { + const auto& in1_t = KernelInput(1).toTensor(); + auto dtype = at::native::result_type(KernelInput(0).toScalar(), in1_t); + KernelOutput(0) = at::native::empty_like( + in1_t, + dtype, + in1_t.options().layout_opt(), + in1_t.options().device_opt(), + in1_t.options().pinned_memory_opt(), + at::MemoryFormat::Preserve); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out(out_t, KernelInput(0).toScalar(), KernelInput(1).toTensor()); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.pow.Tensor_Scalar", + aten_pow_Tensor_Scalar, + { + if (KernelOutput(0).isNone()) { + const auto& in0_t = KernelInput(0).toTensor(); + auto dtype = at::native::result_type(in0_t, KernelInput(1).toScalar()); + KernelOutput(0) = at::native::empty_like( + in0_t, + dtype, + in0_t.options().layout_opt(), + in0_t.options().device_opt(), + in0_t.options().pinned_memory_opt(), + at::MemoryFormat::Preserve); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::pow_out( + out_t, KernelInput(0).toTensor(), KernelInput(1).toScalar()); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.sum.default", aten_sum_default, { + // if (n->inputs().size() != 2 && n->inputs().size() != 4) { + // return nullptr; + // } + const at::Tensor& self = KernelInput(0).toTensor(); + auto dtype = KernelInput(1).toOptional(); + std::vector dim = {}; + bool keepdim = false; + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sum(self, dim, keepdim, dtype); + } else { + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sum_out(out, self, dim, keepdim, dtype); + } +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.sum.dim_IntList", aten_sum_dim_IntList, { + // if (n->inputs().size() != 2 && n->inputs().size() != 4) { + // return nullptr; + // } + const at::Tensor& self = KernelInput(0).toTensor(); + auto dim = KernelInput(1).toDimVector(); + auto keepdim = KernelInput(2).toBool(); + auto dtype = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sum(self, dim, keepdim, dtype); + } else { + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::sum_out(out, self, dim, keepdim, dtype); + } +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.mean.dim", aten_mean_dim, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toDimVector(); + const bool keepdim = KernelInput(2).toBool(); + const auto dtype = KernelInput(3).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + create_empty_from(self, dtype.value_or(self.dtype().toScalarType())); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mean_out(out, self, dim, keepdim, dtype); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.mean.default", aten_mean_default, { + const auto& self = KernelInput(0).toTensor(); + const auto dtype = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + create_empty_from(self, dtype.value_or(self.dtype().toScalarType())); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::mean_out(out, self, /*dim=*/{}, /*keepdim=*/false, dtype); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.max.other", aten_max_other, { + const auto& self = KernelInput(0).toTensor(); + const auto& other = KernelInput(1).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::max(self, other); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::max_out(self, other, out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.max.default", aten_max_default, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& value = KernelOutput(0).toTensor(); + fastResizeToZero(value); + at::cpu::amax_out(value, self); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.sign.Tensor", aten_sign_Tensor, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sign(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sign_out(out_t, in0_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.log.default", aten_log, { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::log(in0_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::log_out(out_t, in0_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.sub.Tensor", aten_sub_Tensor, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sub(in0_t, in1_t, alpha); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sub_out(out_t, in0_t, in1_t, alpha); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.sub.Scalar", aten_sub, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = + at::native::wrapped_scalar_tensor(KernelInput(1).toScalar()); + const auto alpha = KernelInput(2).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::sub(in0_t, in1_t, alpha); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::sub_out(out_t, in0_t, in1_t, alpha); +}) + +// TODO: support clamp_min.Tensor(Tensor self, Tensor min) -> Tensor +// Missing Test Coverage +REGISTER_CPU_KERNEL( + "torch.ops.aten.clamp_min.default", + aten_clamp_min_default, + { + const auto& in0_t = KernelInput(0).toTensor(); + const auto in1_s = KernelInput(1).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::clamp_min(in0_t, in1_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::clamp_min_out(out_t, in0_t, in1_s); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.argmin.default", aten_argmin, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toOptional(); + const auto keepdim = KernelInput(2).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::argmin(in0_t, dim, keepdim); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + if (in0_t.is_contiguous() && dim.has_value()) { + at::native::c2_argmin_out(out_t, in0_t, dim.value(), keepdim); + return; + } + at::cpu::argmin_out(out_t, in0_t, dim, keepdim); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.softmax.int", aten_softmax_int, { + const auto& in_t = KernelInput(0).toTensor(); + const auto& dim = KernelInput(1).toInt(); + const auto& dtype = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::softmax(in_t, dim, dtype); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + auto half_to_float = in_t.scalar_type() == at::ScalarType::Half && + dtype == at::ScalarType::Float; + at::cpu::_softmax_out(out_t, in_t, dim, half_to_float); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dtype", + aten_norm_ScalarOpt_dtype, + { + const auto& in0_t = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + c10::IntArrayRef{}, + false, + KernelInput(2).toScalarType(), + out_t); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.full.default", aten_full, { + const auto& size = KernelInput(0).toDimVector(); + const auto fill_value = KernelInput(1).toScalar(); + const auto dtype = KernelInput(2).toOptional(); + const auto layout = KernelInput(3).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + const auto device = KernelInput(4).toOptional(); + const auto pin_memory = KernelInput(5).toOptional(); + KernelOutput(0) = + at::native::full(size, fill_value, dtype, layout, device, pin_memory); + return; + } + KernelOutput(0) = + at::native::full_out(size, fill_value, KernelOutput(0).toTensor()); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.ones.default", aten_ones, { + const auto size = KernelInput(0).toDimVector(); + if (KernelOutput(0).isNone()) { + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + const auto device = KernelInput(3).toOptional(); + const auto pin_memory = KernelInput(4).toOptional(); + KernelOutput(0) = at::native::ones(size, dtype, layout, device, pin_memory); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::ones_out(size, out_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.ones_like.default", aten_ones_like, { + const auto& self = KernelInput(0).toTensor(); + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + const auto device = KernelInput(3).toOptional(); + const auto pin_memory = KernelInput(4).toOptional(); + const auto memory_format = KernelInput(5).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout, memory_format)) { + KernelOutput(0) = at::native::ones_like( + self, dtype, layout, device, pin_memory, memory_format); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::ones_out(self.sizes(), out_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.zeros.default", aten_zeros, { + const auto size = KernelInput(0).toDimVector(); + const auto dtype = KernelInput(1).toOptional(); + const auto layout = KernelInput(2).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + KernelOutput(0) = at::compositeexplicitautograd::zeros( + size, dtype, layout, std::nullopt, std::nullopt); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::compositeexplicitautograd::zeros_out(out_t, size); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.linalg_norm.default", + aten_linalg_norm_default, + { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(2).toDimVector(); + const auto keepdim = KernelInput(3).toBool(); + const auto dtype = KernelInput(4).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_norm( + self, KernelInput(1).toOptional(), dim, keepdim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_norm_out( + self, + KernelInput(1).toOptional(), + dim, + keepdim, + dtype, + out); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.linalg_norm.ord_str", aten_linalg_norm, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(2).toDimVector(); + const auto keepdim = KernelInput(3).toBool(); + const auto dtype = KernelInput(4).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linalg_norm( + self, KernelInput(1).toStringView(), dim, keepdim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::linalg_norm_out( + self, KernelInput(1).toStringRef(), dim, keepdim, dtype, out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.cat.default", aten_cat, { + const auto inputs = KernelInput(0).toTensorVector(); + TORCH_CHECK(!inputs.empty(), "concat expects non-empty tensor list"); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cat(inputs, dim); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cat_outf(inputs, dim, out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.cumsum.default", aten_cumsum, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto dtype = KernelInput(2).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::cumsum(self, dim, dtype); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::cpu::cumsum_out(out, self, dim, dtype); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.nonzero.default", aten_nonzero, { + const auto& self = KernelInput(0).toTensor(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::nonzero_cpu(self); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::nonzero_out_cpu(self, out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.addmm.default", aten_addmm, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + const auto& in2_t = KernelInput(2).toTensor(); + const auto in3_s = KernelInput(3).toScalar(); + const auto in4_s = KernelInput(4).toScalar(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::cpu::addmm(in0_t, in1_t, in2_t, in3_s, in4_s); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::cpu::addmm_out(out_t, in0_t, in1_t, in2_t, in3_s, in4_s); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.narrow_copy.default", aten_narrow_copy, { + const auto& self = KernelInput(0).toTensor(); // self + const auto dim = KernelInput(1).toInt(); // dim + int64_t start = 0; + if (KernelInput(2).isScalar()) { + start = KernelInput(2).toInt(); + } else { + auto& t = KernelInput(2).toTensor(); + start = t.item(); + } + auto length = KernelInput(3).toInt(); // length + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = + at::native::narrow_copy_dense_cpu(self, dim, start, length); + return; + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::narrow_copy_dense_cpu_out(self, dim, start, length, out); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.repeat.default", aten_repeat, { + const auto& self = KernelInput(0).toTensor(); + const auto repeats = KernelInput(1).toDimVector(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::repeat(self, repeats); + return; + } + at::Tensor& out = KernelOutput(0).toTensor(); + at::native::repeat_out(out, self, repeats); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.max.dim", aten_max_dim, { + const auto& self = KernelInput(0).toTensor(); + auto dim = KernelInput(1).toInt(); + const auto keepdim = KernelInput(2).toBool(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + + if (KernelOutput(1).isNone()) { + KernelOutput(1) = create_empty_from(self, at::kLong); + } + + auto& values = KernelOutput(0).toTensor(); + auto& indices = KernelOutput(1).toTensor(); + fastResizeToZero(values); + fastResizeToZero(indices); + at::cpu::max_out(values, indices, self, dim, keepdim); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.layer_norm.default", aten_layer_norm, { + // ignore KernelInput(5): `bool cudnn_enable=True` + const auto& input_t = KernelInput(0).toTensor(); + const auto normalized_shape = KernelInput(1).toDimVector(); + float eps = KernelInput(4).toDouble(); + + c10::MaybeOwned weight_maybe_owned = + borrow_from_optional_tensor_ivalue(KernelInput(2)); + const at::Tensor& weight = *weight_maybe_owned; + c10::MaybeOwned bias_maybe_owned = + borrow_from_optional_tensor_ivalue(KernelInput(3)); + const at::Tensor& bias = *bias_maybe_owned; + + auto M_N = at::native::_check_layer_norm_inputs( + input_t, normalized_shape, weight, bias); + auto M = M_N.first; + auto N = M_N.second; + auto X = input_t.expect_contiguous(); + auto gamma = weight.expect_contiguous(); + auto beta = bias.expect_contiguous(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::empty_like( + *X, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, + at::MemoryFormat::Contiguous); + } else { + at::native::resize_(KernelOutput(0).toTensor(), X->sizes(), std::nullopt); + } + at::Tensor& out = KernelOutput(0).toTensor(); + at::native::layer_norm_cpu_out(out, *X, *gamma, *beta, eps, M, N); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dim_dtype", + aten_norm_ScalarOpt_dim_dtype, + { + const auto& in0_t = KernelInput(0).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + KernelInput(2).toDimVector(), // dim + KernelInput(3).toBool(), // keepdim + KernelInput(4).toScalarType(), // dtype + out_t); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.norm.ScalarOpt_dim", + aten_norm_ScalarOpt_dim, + { + const auto& in0_t = KernelInput(0).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + + const auto in1_s = KernelInput(1).toOptional(); + at::cpu::norm_outf( + in0_t, + in1_s, + KernelInput(2).toDimVector(), // dim + KernelInput(3).toBool(), // keepdim + out_t); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.full_like.default", aten_full_like, { + const auto in1_s = KernelInput(1).toScalar(); + const auto& in0_t = KernelInput(0).toTensor(); + const auto dtype = KernelInput(2).toOptional(); + const auto layout = KernelInput(3).toOptional(); + if (!hasTensorWithOptions(KernelOutput(0), dtype, layout)) { + const auto device = KernelInput(4).toOptional(); + const auto pin_memory = KernelInput(5).toOptional(); + const auto memory_format = KernelInput(6).toOptional(); + + KernelOutput(0) = at::native::empty_like( + in0_t, dtype, layout, device, pin_memory, memory_format); + } + auto& out_t = KernelOutput(0).toTensor(); + at::native::resize_(out_t, in0_t.sizes(), std::nullopt); + at::native::fill_out(out_t, in1_s); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.linear.default", aten_linear, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_t = KernelInput(1).toTensor(); + auto in2_t = KernelInput(2).toOptional(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::linear(in0_t, in1_t, in2_t); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::linear_out(out_t, in0_t, in1_t, in2_t); +}) + +REGISTER_CPU_KERNEL("torch.ops.aten.where.self", aten_where, { + const auto& cond = KernelInput(0).toTensor(); + const auto& self = KernelInput(1).toTensor(); + const auto& other = KernelInput(2).toTensor(); + + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::native::where_self_out(cond, self, other, out); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_rowwise_offsets.default", + quantized_embedding_bag_byte_rowwise_offsets, + { + const auto& weight = KernelInput(0).toTensor(); + const auto& indices = KernelInput(1).toTensor(); + const auto offsets = KernelInput(2).toOptional(); + const auto pruned_weights = KernelInput(5).toBool(); + const auto per_sample_weights = KernelInput(6).toOptional(); + const auto compressed_indices_mapping = + KernelInput(7).toOptional(); + const auto include_last_offset = KernelInput(8).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(weight, at::kFloat); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::embedding_bag_byte_rowwise_offsets_out( + out_t, + weight, + indices, + offsets, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_4bit_rowwise_offsets.default", + quantized_embedding_bag_4bit_rowwise_offsets, + { + const auto& weight = KernelInput(0).toTensor(); + const auto& indices = KernelInput(1).toTensor(); + const auto offsets = KernelInput(2).toOptional(); + const auto pruned_weights = KernelInput(5).toBool(); + const auto per_sample_weights = KernelInput(6).toOptional(); + const auto compressed_indices_mapping = + KernelInput(7).toOptional(); + const auto include_last_offset = KernelInput(8).toBool(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(weight, at::kFloat); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::embedding_bag_4bit_rowwise_offsets_out( + out_t, + weight, + indices, + offsets, + false, // unused scale_grad_by_freq + 0, // unused mode + pruned_weights, + per_sample_weights, + compressed_indices_mapping, + include_last_offset); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear_dynamic_fp16.default", + quantized_linear_dynamic_fp16, + { + const auto& in_0 = KernelInput(0).toTensor(); + + if (auto& out_0 = KernelOutput(0); out_0.isNone()) { + out_0 = create_empty_from(in_0, at::kFloat); + } + + auto& out_0 = KernelOutput(0).toTensor(); + fastResizeToZero(out_0); + + KernelInput(1).toCustomClass()->apply_dynamic_out( + in_0, out_0, /* reduce_range= */ false); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear_relu_dynamic_fp16.default", + quantized_linear_relu_dynamic_fp16, + { + const auto& in_0 = KernelInput(0).toTensor(); + + if (auto& out_0 = KernelOutput(0); out_0.isNone()) { + out_0 = create_empty_from(in_0, at::kFloat); + } + + auto& out_0 = KernelOutput(0).toTensor(); + fastResizeToZero(out_0); + + KernelInput(1) + .toCustomClass() + ->apply_dynamic_out(in_0, out_0, /* reduce_range= */ false) + .relu_(); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.linear.default", + quantized_linear_default, + { + const auto& in_0 = KernelInput(0).toTensor(); + const auto w_prepack = + KernelInput(1).toCustomClass(); + const auto output_scale = KernelInput(2).toDouble(); + const auto output_zero_point = KernelInput(3).toInt(); + if (auto& out_t = KernelOutput(0); out_t.isNone()) { + out_t = at::native::empty_affine_quantized( + {0}, + c10::kQUInt8, + std::nullopt, + c10::kCPU, + false, + output_scale, + output_zero_point, + std::nullopt); + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + w_prepack->apply_out(in_0, output_scale, output_zero_point, out_tensor); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.logit.default", aten_logit, { + const auto& in0_t = KernelInput(0).toTensor(); + const auto& in1_d = KernelInput(1).toOptional(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(in0_t); + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::logit_out(in0_t, in1_d, out_t); +}) + +REGISTER_CPU_KERNEL( + "torch.ops.aten.slice_scatter.default", + aten_slice_scatter, + { + const auto& self = KernelInput(0).toTensor(); + const auto& src = KernelInput(1).toTensor(); + const int64_t dim = KernelInput(2).toInt(); + const auto& start = KernelInput(3).toOptional(); + const auto& end = KernelInput(4).toOptional(); + int64_t step = KernelInput(5).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = create_empty_from(self); + } + auto& out = KernelOutput(0).toTensor(); + fastResizeToZero(out); + at::slice_scatter_out(out, self, src, dim, start, end, step); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_unpack.default", + quantized_embedding_bag_byte_unpack_default, + { + const auto& weight = KernelInput(0).toTensor(); + if (auto& out = KernelOutput(0); out.isNone()) { + out = at::empty( + {}, + weight.options().dtype(at::kFloat), + weight.suggest_memory_format()); + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + at::native::qembeddingbag_byte_unpack_out(out_tensor, weight); + }) + +REGISTER_CPU_KERNEL( + "torch.ops.quantized.embedding_bag_byte_prepack.default", + embedding_bag_byte_prepack_default, + { + const auto& weight = KernelInput(0).toTensor(); + if (auto& out_t = KernelOutput(0); out_t.isNone()) { + KernelOutput(0) = at::native::qembeddingbag_byte_prepack(weight); + return; + } + auto& out_tensor = KernelOutput(0).toTensor(); + fastResizeToZero(out_tensor); + at::native::qembeddingbag_byte_prepack_out(out_tensor, weight); + }) + +REGISTER_CPU_KERNEL("torch.ops.aten.stack.default", aten_stack, { + const auto& inputs = KernelInput(0).toTensorVector(); + const auto dim = KernelInput(1).toInt(); + if (KernelOutput(0).isNone()) { + KernelOutput(0) = at::native::_stack_cpu(inputs, dim); + return; + } + auto& out_t = KernelOutput(0).toTensor(); + fastResizeToZero(out_t); + at::native::_stack_out_cpu(inputs, dim, out_t); +}) + +class OpKernel_aten__to_copy : public C10Kernel { + public: + explicit OpKernel_aten__to_copy(const Node* node, c10::Device device) + : C10Kernel( + node, + device, + torch::nativert::OpKernelKind::kStaticDispatchKernel, + torch::nativert::AliasingSpec{ + {/* input_idx = */ 0, /* output_idx = */ 0}}) { + dtype_ = attribute(1).toOptional(); + layout_ = attribute(2).toOptional(); + device_ = attribute(3).toOptional(); + pin_memory_ = attribute(4).toOptional(); + non_blocking_ = attribute(5).toBool(); + memory_format_ = attribute(6).toOptional(); + + has_memory_format_ = memory_format_.has_value(); + + if (memory_format_.has_value()) { + TORCH_CHECK( + memory_format_.value() != c10::MemoryFormat::ChannelsLast && + memory_format_.value() != c10::MemoryFormat::ChannelsLast3d, + "Static Kernel for aten._to_copy doesn't correctly handle the ChannelsLast(3d) memory format. If you are running into this error, please report to nativert oncall."); + } + + if (device_.has_value()) { + TORCH_CHECK( + device_.value().is_cpu(), + "Static kernel for aten._to_copy only supports CPU device, but got ", + device_.value()); + } + } + + void computeInternal(ExecutionFrame& executionFrame) const override final { + const auto& self = KernelInput(0).toTensor(); + auto& out = KernelOutput(0); + + // skip if the _to_copy is a no-op + if (dtype_.has_value() && self.dtype() == dtype_.value() && + !has_memory_format_ && !device_.has_value() && !layout_.has_value()) { + if (out.isNone()) { + out = at::native::alias(self); + return; + } + + auto* in_t = self.unsafeGetTensorImpl(); + auto* out_t = out.toTensor().unsafeGetTensorImpl(); + + // it's possible that the input storage has been updated + if (!out_t->storage().is_alias_of(in_t->storage())) { + out_t->set_storage_keep_dtype(in_t->storage()); + } + + // in case in was re-sized/strided from the prev. impl + // we need to make sure the metadata is consistent between + // in_t and out_t + + if (in_t->storage_offset() != out_t->storage_offset()) { + out_t->set_storage_offset(in_t->storage_offset()); + } + + if (in_t->sizes_and_strides() != out_t->sizes_and_strides()) { + out_t->set_sizes_and_strides(self.sizes(), self.strides()); + } + + return; + } + + std::optional memory_format = + c10::MemoryFormat::Preserve; + if (has_memory_format_) { + memory_format = memory_format_.value_or(c10::MemoryFormat::Preserve); + } + + bool copy_strides = false; + if (memory_format == c10::MemoryFormat::Preserve) { + if (self.is_non_overlapping_and_dense()) { + memory_format = std::nullopt; + copy_strides = true; + } else { + memory_format = self.suggest_memory_format(); + } + } + + bool need_to_allocate_output = true; + if (out.isTensor()) { + const auto& existing_output = out.toTensor(); + if ((has_memory_format_ && + !existing_output.is_contiguous( + memory_format.value_or(c10::MemoryFormat::Contiguous)))) { + need_to_allocate_output = true; + } else { + need_to_allocate_output = false; + } + } + + // See Note [Explicit nullopt MemoryFormat argument] + // Can't use size {0} if memory_format is ChannelLast + if (need_to_allocate_output) { + out = at::detail::empty_cpu( + self.sizes(), + dtype_.value_or(self.scalar_type()), + layout_, + device_, + std::nullopt, + memory_format); + } else { + if (has_memory_format_) { + memory_format = memory_format_.value_or(c10::MemoryFormat::Preserve); + } else { + memory_format = c10::MemoryFormat::Preserve; + } + } + + copy_strides = copy_strides || + (memory_format == c10::MemoryFormat::Preserve && + self.is_non_overlapping_and_dense()); + + auto& out_t = out.toTensor(); + fastResizeToZero(out_t); + at::native::to_copy_out( + out_t, self, non_blocking_, copy_strides, memory_format); + } + + private: + std::optional dtype_; + std::optional layout_; + std::optional device_; + std::optional pin_memory_; + bool non_blocking_ = false; + std::optional memory_format_; + bool has_memory_format_; +}; + +C10_REGISTER_TYPED_CLASS( + StaticallyDispatchedCPUKernelRegistry, + "torch.ops.aten._to_copy.default", + OpKernel_aten__to_copy) + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/KernelRegistry.h b/torch/nativert/kernels/KernelRegistry.h new file mode 100644 index 000000000000..03293871fef2 --- /dev/null +++ b/torch/nativert/kernels/KernelRegistry.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include + +namespace torch::nativert { + +TORCH_DECLARE_REGISTRY( + StaticallyDispatchedCPUKernelRegistry, + OpKernel, + const Node*, + c10::Device); + +#define REGISTER_CPU_KERNEL(name, id, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kStaticDispatchKernel) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +#define ALIASING_SPEC(...) __VA_ARGS__ + +#define REGISTER_ALIASING_CPU_KERNEL(name, id, aliasing_spec, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kNativeStaticDispatchKernel, \ + aliasing_spec) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +#define REGISTER_NATIVE_CPU_KERNEL(name, id, ...) \ + class OpKernel_##id : public C10Kernel { \ + public: \ + OpKernel_##id(const Node* node, c10::Device device) \ + : C10Kernel( \ + node, \ + device, \ + torch::nativert::OpKernelKind::kNativeStaticDispatchKernel) {} \ + void computeInternal(torch::nativert::ExecutionFrame& executionFrame) \ + const override final { \ + __VA_ARGS__; \ + } \ + }; \ + C10_REGISTER_TYPED_CLASS( \ + StaticallyDispatchedCPUKernelRegistry, name, OpKernel_##id) + +inline at::Tensor create_empty_from(const at::Tensor& t) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + t.device(), + std::nullopt, + std::nullopt); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::ScalarType dtype) { + return at::detail::empty_cpu( + {0}, dtype, t.layout(), t.device(), std::nullopt, std::nullopt); +} + +inline at::Tensor create_empty_from(const at::Tensor& t, c10::Device device) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + device, + std::nullopt, + std::nullopt); +} +inline at::Tensor create_empty_from(const at::Tensor& t, c10::Layout layout) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + layout, + t.device(), + std::nullopt, + std::nullopt); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::MemoryFormat memory_format) { + return at::detail::empty_cpu( + {0}, + c10::typeMetaToScalarType(t.dtype()), + t.layout(), + t.device(), + std::nullopt, + memory_format); +} + +inline at::Tensor create_empty_from( + const at::Tensor& t, + c10::ScalarType dtype, + c10::MemoryFormat memory_format) { + return at::detail::empty_cpu( + {0}, dtype, t.layout(), t.device(), std::nullopt, memory_format); +} + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/NativeKernels.cpp b/torch/nativert/kernels/NativeKernels.cpp new file mode 100644 index 000000000000..7acd82102266 --- /dev/null +++ b/torch/nativert/kernels/NativeKernels.cpp @@ -0,0 +1,113 @@ +#include + +#include +#include +#include + +namespace torch::nativert { + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.slice.Tensor", aten_slice_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto& dim = KernelInput(1).toInt(); + const auto& start = KernelInput(2).toOptional(); + const auto& end = KernelInput(3).toOptional(); + const auto& step = KernelInput(4).toInt(); + KernelOutput(0) = at::native::slice(self, dim, start, end, step); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.sym_size.int", aten_sym_size_int, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + auto& out = KernelOutput(0); + TORCH_CHECK(dim >= 0 && dim < self.dim(), "Invalid dimension"); + out = self.sym_size(dim); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.reshape.default", aten_reshape, { + const auto& self = KernelInput(0).toTensor(); + const auto& shape = KernelInput(1).toIntVector(); + KernelOutput(0) = at::native::reshape(self, shape); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.view.default", aten_view, { + const auto& self = KernelInput(0).toTensor(); + const auto& size = KernelInput(1).toIntVector(); + KernelOutput(0) = at::native::view(self, size); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.permute.default", aten_permute, { + const auto& self = KernelInput(0).toTensor(); + const auto& dims = KernelInput(1).toDimVector(); + KernelOutput(0) = at::native::permute(self, dims); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.select.int", aten_select, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + const auto index = KernelInput(2).toInt(); + KernelOutput(0) = at::native::select(self, dim, index); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.split.Tensor", aten_split_Tensor, { + const auto& self = KernelInput(0).toTensor(); + const auto split_size = KernelInput(1).toInt(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = at::native::split(self, split_size, dim); +}) + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.split_with_sizes.default", + aten_split_with_sizes, + { + const auto& self = KernelInput(0).toTensor(); + const auto& split_sizes = KernelInput(1).toIntList(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = + at::native::split_with_sizes(self, split_sizes.vec(), dim); + }) + +REGISTER_NATIVE_CPU_KERNEL( + "torch.ops.aten.tensor_split.sections", + aten_tensor_split_sections, + { + const auto& self = KernelInput(0).toTensor(); + const auto sections = KernelInput(1).toInt(); + const auto dim = KernelInput(2).toInt(); + KernelOutput(0) = + at::native::tensor_split_sections_symint(self, sections, dim); + }) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.item.default", aten_item, { + const auto& self = KernelInput(0).toTensor(); + KernelOutput(0) = at::native::item(self); +}) + +REGISTER_NATIVE_CPU_KERNEL("torch.ops.aten.narrow.default", aten_narrow, { + const auto& self = KernelInput(0).toTensor(); + const auto dim = KernelInput(1).toInt(); + int64_t start = 0; + if (KernelInput(2).isScalar()) { + start = KernelInput(2).toInt(); + } else { + auto& t = KernelInput(2).toTensor(); + start = t.item(); + } + const auto length = KernelInput(3).toInt(); + TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + auto cur_size = self.sizes()[dim]; + if (start != cur_size && start < 0) { + start = at::maybe_wrap_dim(start, cur_size); + } + TORCH_CHECK( + length >= 0 && start <= cur_size - length, + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + KernelOutput(0) = at::native::slice(self, dim, start, start + length, 1); +}) + +} // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp index e6f69634a71b..b9071c8ecc4e 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.cpp +++ b/torch/nativert/kernels/PrimKernelRegistry.cpp @@ -10,7 +10,7 @@ namespace torch::nativert { -C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); +C10_DEFINE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*) namespace { @@ -57,7 +57,7 @@ class OpKernel_prim_listpack : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.ListPack", - OpKernel_prim_listpack); + OpKernel_prim_listpack) REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, { RECORD_USER_SCOPE("nativert::OpKernel_prim_listunpack"); @@ -65,11 +65,11 @@ REGISTER_PRIM_KERNEL("prim.ListUnpack", prim_listunpack, { for (const auto& [i, ivalue] : c10::enumerate(inputListRef)) { KernelOutput(i) = ivalue; } -}); +}) // Noop for input and output -REGISTER_PRIM_KERNEL("prim.Input", prim_input, {}); -REGISTER_PRIM_KERNEL("prim.Output", prim_output, {}); +REGISTER_PRIM_KERNEL("prim.Input", prim_input, {}) +REGISTER_PRIM_KERNEL("prim.Output", prim_output, {}) namespace { @@ -114,7 +114,7 @@ class OpKernel_variadic_concat : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.VarConcat", - OpKernel_variadic_concat); + OpKernel_variadic_concat) namespace { @@ -158,6 +158,6 @@ class OpKernel_variadic_stack : public OpKernel { C10_REGISTER_TYPED_CLASS( PrimKernelRegistry, "prim.VarStack", - OpKernel_variadic_stack); + OpKernel_variadic_stack) } // namespace torch::nativert diff --git a/torch/nativert/kernels/PrimKernelRegistry.h b/torch/nativert/kernels/PrimKernelRegistry.h index 89e9c29e7dcb..f050ff79b86f 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.h +++ b/torch/nativert/kernels/PrimKernelRegistry.h @@ -21,7 +21,7 @@ TORCH_DECLARE_REGISTRY(PrimKernelRegistry, OpKernel, const Node*); __VA_ARGS__; \ } \ }; \ - C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id); + C10_REGISTER_TYPED_CLASS(PrimKernelRegistry, name, OpKernel_##id) inline bool checkResizedDataPtr(at::Tensor& t) { auto const prev_data_ptr = t.data_ptr(); diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index ce592c1ed342..7dc66696d110 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -9,10 +9,22 @@ import operator import warnings from enum import Enum -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, Union import torch from torch import Tensor + + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +try: + from typing import NotRequired +except ImportError: + from typing_extensions import NotRequired + from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch._higher_order_ops.utils import _set_compilation_env from torch._prims_common import DeviceLikeType @@ -24,9 +36,40 @@ from torch.utils._pytree import tree_map_only +# Private debug flag to disable internal compilation wrapping for debugging purposes. +# WARNING: This is intended ONLY for debugging score_mod and mask_mod functions. +# When enabled, this bypasses the required internal compilation that ensures correctness +# and performance. Only use this temporarily when you need to set breakpoints +# in your score_mod/mask_mod functions during development. +# +# This flag only affects the internal compilation when flex_attention is called directly. +# If you have already wrapped flex_attention in torch.compile(), this flag has no effect +# and the user's compilation will still occur. +# +# Usage: +# import torch.nn.attention.flex_attention as fa +# fa._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True +# # Now you can set breakpoints in your score_mod/mask_mod +# output = fa.flex_attention(q, k, v, score_mod=my_score_mod) +# +_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = False + +_WARNINGS_SHOWN: set[str] = set() + + +def _warn_once( + warning_id: str, message: str, category: type[Warning] = UserWarning +) -> None: + """Helper to ensure each warning is shown only once per process.""" + if warning_id not in _WARNINGS_SHOWN: + warnings.warn(message, category, stacklevel=2) + _WARNINGS_SHOWN.add(warning_id) + + __all__ = [ "BlockMask", "flex_attention", + "FlexKernelOptions", "create_block_mask", "create_mask", "create_nested_block_mask", @@ -39,6 +82,123 @@ _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +class FlexKernelOptions(TypedDict, total=False): + """Options for controlling the behavior of FlexAttention kernels. + + These options are passed to the underlying Triton kernels to control performance + and numerical behavior. Most users will not need to specify these options as the + default autotuning provides good performance. + + The options can be prefixed with 'fwd_' or 'bwd_' to apply only to forward or + backward pass respectively. For example: 'fwd_BLOCK_M' and 'bwd_BLOCK_M1'. + + Note: + We currently do not provide any backward compatibility guarantees for these options. + That being said most of these have remained pretty stable since their introduction. But + We do not consider this part of the public API just yet. We think that some documentation + Is better than secret hidden flags, but we may change these options in the future. + + Example Usage: + .. code-block:: python + + # Using dictionary (backward compatible) + kernel_opts = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": True} + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Using TypedDict (recommended for type safety) + from torch.nn.attention.flex_attention import FlexKernelOptions + + kernel_opts: FlexKernelOptions = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "PRESCALE_QK": True, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + + # Forward/backward specific options + kernel_opts: FlexKernelOptions = { + "fwd_BLOCK_M": 64, + "bwd_BLOCK_M1": 32, + "PRESCALE_QK": False, + } + output = flex_attention(q, k, v, kernel_options=kernel_opts) + """ + + # Performance tuning options + num_warps: NotRequired[int] + """Number of warps to use in the CUDA kernel. Higher values may improve performance + but increase register pressure. Default is determined by autotuning.""" + + num_stages: NotRequired[int] + """Number of pipeline stages in the CUDA kernel. Higher values may improve performance + but increase shared memory usage. Default is determined by autotuning.""" + + BLOCK_M: NotRequired[int] + """Thread block size for the sequence length dimension of Q in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + BLOCK_N: NotRequired[int] + """Thread block size for the sequence length dimension of K/V in forward pass. + Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" + + # Backward-specific block sizes (when prefixed with 'bwd_') + BLOCK_M1: NotRequired[int] + """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. + Default is determined by autotuning.""" + + BLOCK_N1: NotRequired[int] + """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. + Default is determined by autotuning.""" + + BLOCK_M2: NotRequired[int] + """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. + Default is determined by autotuning.""" + + BLOCK_N2: NotRequired[int] + """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. + Default is determined by autotuning.""" + + PRESCALE_QK: NotRequired[bool] + """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but + may have more numerical error. Default: False.""" + + ROWS_GUARANTEED_SAFE: NotRequired[bool] + """If True, guarantees that at least one value in each row is not masked out. + Allows skipping safety checks for better performance. Only set this if you are certain + your mask guarantees this property. For example, causal attention is guaranteed safe + because each query has at least 1 key-value to attend to. Default: False.""" + + BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] + """If True, guarantees that all blocks in the mask are contiguous. + Allows optimizing block traversal. For example, causal masks would satisfy this, + but prefix_lm + sliding window would not. Default: False.""" + + WRITE_DQ: NotRequired[bool] + """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. + Setting this to False will force this to happen in the DK loop which depending on your + specific score_mod and mask_mod might be faster. Default: True.""" + + FORCE_USE_FLEX_ATTENTION: NotRequired[bool] + """If True, forces the use of the flex attention kernel instead of potentially using + the more optimized flex-decoding kernel for short sequences. This can be a helpful + option for debugging. Default: False.""" + + USE_TMA: NotRequired[bool] + """Whether to use Tensor Memory Accelerator (TMA) on supported hardware. + This is experimental and may not work on all hardware, currently specific + to NVIDIA GPUs Hopper+. Default: False.""" + + # ROCm-specific options + kpack: NotRequired[int] + """ROCm-specific kernel packing parameter.""" + + matrix_instr_nonkdim: NotRequired[int] + """ROCm-specific matrix instruction non-K dimension.""" + + waves_per_eu: NotRequired[int] + """ROCm-specific waves per execution unit.""" + + class _ModificationType(Enum): """Enum for the type of modification function. - SCORE_MOD: score_mod function which accepts a score as the first argument @@ -1244,7 +1404,7 @@ def flex_attention( scale: Optional[float] = None, enable_gqa: bool = False, return_lse: bool = False, - kernel_options: Optional[dict[str, Any]] = None, + kernel_options: Optional[FlexKernelOptions] = None, ) -> Union[Tensor, tuple[Tensor, Tensor]]: r"""This function implements scaled dot product attention with an arbitrary attention score modification function. @@ -1280,7 +1440,9 @@ def score_mod( scale (Optional[float]): Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) and broadcasts key/value heads to query heads. return_lse (bool): Whether to return the logsumexp of the attention scores. Default is False. - kernel_options (Optional[Dict[str, Any]]): Options to pass into the Triton kernels. + kernel_options (Optional[FlexKernelOptions]): + Options to control the behavior of the underlying Triton kernels. + See :class:`FlexKernelOptions` for available options and usage examples. Returns: output (Tensor): Attention output; shape :math:`(B, Hq, L, Ev)`. @@ -1416,6 +1578,18 @@ def score_mod( else: return out + if not _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + _warn_once( + warning_id="flex_attention_performance", + message=( + "flex_attention called without torch.compile() - this will use an unfused implementation that materializes the full scores matrix instead of generating a fused kernel.\n\n" + "SOLUTION: Use torch.compile(flex_attention)(...)\n\n" + "If you want to debug your score_mod/mask_mod, you can set:\n" + "torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True\n\n" + "This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results." + ), + ) + if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support") @@ -1438,9 +1612,15 @@ def _flex_attention_hop_wrapper(*args, **kwargs): ) else: backend = "eager" - out, lse = torch.compile( - _flex_attention_hop_wrapper, backend=backend, fullgraph=True - )( + + if _FLEX_ATTENTION_DISABLE_COMPILE_DEBUG: + flex_fn = _flex_attention_hop_wrapper + else: + flex_fn = torch.compile( + _flex_attention_hop_wrapper, backend=backend, fullgraph=True + ) + + out, lse = flex_fn( query, key, value, @@ -1449,7 +1629,7 @@ def _flex_attention_hop_wrapper(*args, **kwargs): scale, kernel_options, ) - if return_lse: - return out, lse * math.log(2) - else: - return out + if return_lse: + return out, lse * math.log(2) + else: + return out diff --git a/torch/onnx/README.md b/torch/onnx/README.md index c4691ea01802..7c8596365f27 100644 --- a/torch/onnx/README.md +++ b/torch/onnx/README.md @@ -23,7 +23,7 @@ symbolic_opset9.py. To extend support for updated operators in different opset versions on top of opset 9, simply add the updated symbolic functions in the respective symbolic_opset{version}.py file. -Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. +Check out topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example. ## Editing Symbolic Files diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 410b34b042cf..6c301ef294eb 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -38,8 +38,6 @@ "OnnxExporterError", "ONNXProgram", "enable_fake_mode", - # DORT / torch.compile - "is_onnxrt_backend_supported", ] from typing import Any, Callable, TYPE_CHECKING @@ -51,12 +49,6 @@ from ._internal._exporter_legacy import enable_fake_mode from ._internal.exporter._onnx_program import ONNXProgram -from ._internal.onnxruntime import ( - is_onnxrt_backend_supported, - OrtBackend as _OrtBackend, - OrtBackendOptions as _OrtBackendOptions, - OrtExecutionProvider as _OrtExecutionProvider, -) from ._type_utils import JitScalarType from .errors import OnnxExporterError from .utils import ( @@ -98,11 +90,7 @@ JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" OnnxExporterError.__module__ = "torch.onnx" -_OrtBackend.__module__ = "torch.onnx" -_OrtBackendOptions.__module__ = "torch.onnx" -_OrtExecutionProvider.__module__ = "torch.onnx" enable_fake_mode.__module__ = "torch.onnx" -is_onnxrt_backend_supported.__module__ = "torch.onnx" producer_name = "pytorch" producer_version = _C_onnx.PRODUCER_VERSION diff --git a/torch/onnx/_constants.py b/torch/onnx/_constants.py index b3c386b701d9..87ff04da8cd1 100644 --- a/torch/onnx/_constants.py +++ b/torch/onnx/_constants.py @@ -6,7 +6,7 @@ ONNX_MIN_OPSET = 7 ONNX_MAX_OPSET = 23 ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET = 20 -ONNX_DEFAULT_OPSET = 18 +ONNX_DEFAULT_OPSET = 20 ONNX_CONSTANT_FOLDING_MIN_OPSET = 9 PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues" diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index b3150ef9cdeb..f9ae42b26b84 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -3,32 +3,18 @@ __all__ = [ - "ExportOptions", - "ONNXRuntimeOptions", - "OnnxRegistry", "enable_fake_mode", ] -import abc import contextlib import dataclasses import logging -import warnings -from collections import defaultdict -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import deprecated +from typing import Any, TYPE_CHECKING import torch import torch._ops -from torch.onnx._internal import io_adapter -from torch.onnx._internal._lazy_import import onnxscript_apis -from torch.onnx._internal.exporter import _constants -from torch.onnx._internal.fx import ( - decomposition_table, - patcher as patcher, - registration, -) +from torch.onnx._internal.fx import patcher as patcher # We can only import onnx from this module in a type-checking context to ensure that @@ -36,10 +22,6 @@ # 'import onnx' inside of dynamo_export (by way of _assert_dependencies). if TYPE_CHECKING: import io - from collections.abc import Mapping, Sequence - - import onnxruntime - import onnxscript from torch._subclasses import fake_tensor @@ -62,219 +44,6 @@ class ONNXFakeContext: """List of paths of files that contain the model :meth:`state_dict`""" -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", -) -class OnnxRegistry: - """Registry for ONNX functions. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - The registry maintains a mapping from qualified names to symbolic functions under a - fixed opset version. It supports registering custom onnx-script functions and for - dispatcher to dispatch calls to the appropriate function. - - """ - - def __init__(self) -> None: - """Initializes the registry""" - - # NOTE: _registry is the registry maps OpNameto a list of ONNXFunctions. It is important - # not to directly modify this variable. Instead, access to it should be done through - # the public methods: register_custom_op, get_ops, and is_registered_op. - self._registry: dict[registration.OpName, list[registration.ONNXFunction]] = ( - defaultdict(list) - ) - - self._opset_version = _constants.TORCHLIB_OPSET - warnings.warn( - f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a " - "different opset version, please register them with register_custom_op." - ) - - self._initiate_registry_from_torchlib() - - @property - def opset_version(self) -> int: - """The ONNX opset version the exporter should target.""" - - return self._opset_version - - def _initiate_registry_from_torchlib(self) -> None: - """Populates the registry with ATen functions from torchlib. - - Args: - torchlib_registry: The torchlib registry to use for populating the registry. - """ - for meta in onnxscript_apis.get_torchlib_ops(): - internal_name_instance = registration.OpName.from_qualified_name( - meta.qualified_name - ) - symbolic_function = registration.ONNXFunction( - onnx_function=meta.function, # type: ignore[arg-type] - op_full_name=internal_name_instance.qualified_name(), - is_custom=False, - is_complex=meta.is_complex, - ) - self._register(internal_name_instance, symbolic_function) - - def _register( - self, - internal_qualified_name: registration.OpName, - symbolic_function: registration.ONNXFunction, - ) -> None: - """Registers a ONNXFunction to an operator. - - Args: - internal_qualified_name: The qualified name of the operator to register: OpName. - symbolic_function: The ONNXFunction to register. - """ - self._registry[internal_qualified_name].append(symbolic_function) - - def register_op( - self, - function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - namespace: str, - op_name: str, - overload: str | None = None, - is_complex: bool = False, - ) -> None: - """Registers a custom operator: torch.ops.... - - Args: - function: The onnx-sctip function to register. - namespace: The namespace of the operator to register. - op_name: The name of the operator to register. - overload: The overload of the operator to register. If it's default overload, - leave it to None. - is_complex: Whether the function is a function that handles complex valued inputs. - - Raises: - ValueError: If the name is not in the form of 'namespace::op'. - """ - internal_name_instance = registration.OpName.from_name_parts( - namespace=namespace, op_name=op_name, overload=overload - ) - symbolic_function = registration.ONNXFunction( - onnx_function=function, - op_full_name=internal_name_instance.qualified_name(), - is_custom=True, - is_complex=is_complex, - ) - self._register(internal_name_instance, symbolic_function) - - def get_op_functions( - self, namespace: str, op_name: str, overload: str | None = None - ) -> list[registration.ONNXFunction] | None: - """Returns a list of ONNXFunctions for the given op: torch.ops.... - - The list is ordered by the time of registration. The custom operators should be - in the second half of the list. - - Args: - namespace: The namespace of the operator to get. - op_name: The name of the operator to get. - overload: The overload of the operator to get. If it's default overload, - leave it to None. - Returns: - A list of ONNXFunctions corresponding to the given name, or None if - the name is not in the registry. - """ - internal_name_instance = registration.OpName.from_name_parts( - namespace=namespace, op_name=op_name, overload=overload - ) - return self._registry.get(internal_name_instance) - - def is_registered_op( - self, namespace: str, op_name: str, overload: str | None = None - ) -> bool: - """Returns whether the given op is registered: torch.ops.... - - Args: - namespace: The namespace of the operator to check. - op_name: The name of the operator to check. - overload: The overload of the operator to check. If it's default overload, - leave it to None. - - Returns: - True if the given op is registered, otherwise False. - """ - functions = self.get_op_functions( - namespace=namespace, op_name=op_name, overload=overload - ) - return functions is not None - - def _all_registered_ops(self) -> set[str]: - """Returns the set of all registered function names.""" - return { - op_name_class.qualified_name() for op_name_class in self._registry.keys() - } - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=None, -) -class ExportOptions: - """Options to influence the TorchDynamo ONNX exporter. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - Attributes: - dynamic_shapes: Shape information hint for input/output tensors. - When ``None``, the exporter determines the most compatible setting. - When ``True``, all input shapes are considered dynamic. - When ``False``, all input shapes are considered static. - fake_context: The fake context used for symbolic tracing. - onnx_registry: The ONNX registry used to register ATen operators to ONNX functions. - """ - - def __init__( - self, - *, - dynamic_shapes: bool | None = True, - fake_context: ONNXFakeContext | None = None, - onnx_registry: OnnxRegistry | None = None, - ): - self.dynamic_shapes = dynamic_shapes - self.fake_context = fake_context - self.onnx_registry = onnx_registry - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", - category=None, -) -class ResolvedExportOptions(ExportOptions): - """Consolidates :class:`ExportOptions` with default values. - All unspecified options from :class:`ExportOptions` are assigned a default value. - This is an internal class and its API may be changed at any time without notice. - """ - - def __init__(self): - from torch.onnx._internal.fx import ( - dynamo_graph_extractor, - onnxfunction_dispatcher, - ) - - self.dynamic_shapes: bool = True - self.fx_tracer: dynamo_graph_extractor.DynamoExport = ( - dynamo_graph_extractor.DynamoExport() - ) - self.fake_context = None - self.onnx_registry: OnnxRegistry = OnnxRegistry() - self.decomposition_table = ( - decomposition_table.create_onnx_friendly_decomposition_table( # type: ignore[assignment] - self.onnx_registry - ) - ) - self.onnxfunction_dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher( - self.onnx_registry, - ) - - @contextlib.contextmanager def enable_fake_mode(): """Enable fake mode for the duration of the context. @@ -347,150 +116,3 @@ def enable_fake_mode(): fake_context.state_dict_paths = tuple( patcher_context.paths, ) # type: ignore[assignment] - - -@deprecated( - "torch.onnx.dynamo_export is deprecated since 2.7.0. Please use torch.onnx.export(..., dynamo=True) instead.", -) -class ONNXRuntimeOptions: - """Options to influence the execution of the ONNX model through ONNX Runtime. - - .. deprecated:: 2.7 - Please use ``torch.onnx.export(..., dynamo=True)`` instead. - - Attributes: - session_options: ONNX Runtime session options. - execution_providers: ONNX Runtime execution providers to use during model execution. - execution_provider_options: ONNX Runtime execution provider options. - """ - - session_options: Sequence[onnxruntime.SessionOptions] | None = None - """ONNX Runtime session options.""" - - execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None - """ONNX Runtime execution providers to use during model execution.""" - - execution_provider_options: Sequence[dict[Any, Any]] | None = None - """ONNX Runtime execution provider options.""" - - def __init__( - self, - *, - session_options: Sequence[onnxruntime.SessionOptions] | None = None, - execution_providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, - execution_provider_options: Sequence[dict[Any, Any]] | None = None, - ): - self.session_options = session_options - self.execution_providers = execution_providers - self.execution_provider_options = execution_provider_options - - -class FXGraphExtractor(abc.ABC): - """Abstract interface for FX graph extractor engines. - This class isolates FX extraction logic from the rest of the export logic. - That allows a single ONNX exporter that can leverage different FX graphs.""" - - def __init__(self) -> None: - super().__init__() - self.input_adapter: io_adapter.InputAdapter = io_adapter.InputAdapter() - self.output_adapter: io_adapter.OutputAdapter = io_adapter.OutputAdapter() - - @abc.abstractmethod - def generate_fx( - self, - options: ResolvedExportOptions, - model: torch.nn.Module | Callable, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - """Analyzes user ``model`` and generates a FX graph. - Args: - options: The export options. - model: The user model. - model_args: The model's positional input arguments. - model_kwargs: The model's keyword input arguments. - Returns: - The generated FX Graph. - """ - ... - - # TODO: Design the passes API - @abc.abstractmethod - def pre_export_passes( - self, - options: ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - """Applies pre-export passes to the FX graph. - - Pre-export passes are FX-to-FX graph transformations that make the graph - more palatable for the FX-to-ONNX conversion. - For example, it can be used to flatten model input/output, add explicit - casts to the graph, replace/decompose operators, functionalize the graph, etc. - """ - ... - - -def common_pre_export_passes( - options: ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], -): - # TODO: Import here to prevent circular dependency - from torch.onnx._internal.fx import passes - - # Apply decomposition table to the input graph. - module = passes.Decompose( - fx_module, - options.decomposition_table, # type: ignore[arg-type] - enable_dynamic_axes=options.dynamic_shapes, - allow_fake_constant=options.fake_context is not None, - ).run(*fx_module_args) - - # ONNX does not support views and mutations. - # Functionalize to get a semantically equivalent graph without mutations. - module = passes.Functionalize( - module, - enable_dynamic_axes=options.dynamic_shapes, - allow_fake_constant=options.fake_context is not None, - ).run(*fx_module_args) - - # Input mutations are detected and distilled after `Functionalize` pass. - # Remove them since ONNX inference does not need them. - module = passes.RemoveInputMutation(module).run(*fx_module_args) - - # ONNX does not support concept of (implicit) type promotion. - # Insert type casts explicitly where needed. - module = passes.InsertTypePromotion(module).run() - - if isinstance(original_model, torch.nn.Module): - module = passes.RestoreParameterAndBufferNames(module, original_model).run() - - # ONNX does not support None inputs. During graph building, all None inputs - # are removed. Here we register this step to input adapter. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNoneInputStep()) - - # NOTE: temp workaround for https://github.com/pytorch/pytorch/issues/99534 - # Dynamo doesn't support non-tensor inputs. - options.fx_tracer.input_adapter.append_step(io_adapter.RemoveNonTensorInputStep()) - - # ONNX does not support complex inputs. During graph building, all complex inputs - # are converted to real representation inputs. Here we register this step to - # input/output adapter. - options.fx_tracer.input_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationInputStep() - ) - - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - options.fx_tracer.output_adapter.append_step(io_adapter.FlattenOutputStep()) - - # Output post-processing steps should happen after `FlattenOutputStep`. - options.fx_tracer.output_adapter.append_step( - io_adapter.ConvertComplexToRealRepresentationOutputStep() - ) - - return module diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index 3557ef099309..7cde0bd35177 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -28,7 +28,7 @@ def __getattr__(self, attr: str) -> object: # NOTE: Add additional used imports here. if TYPE_CHECKING: import onnx - import onnx_ir # type: ignore[import-untyped] + import onnx_ir # type: ignore[import-untyped, import-not-found] import onnxscript import onnxscript._framework_apis.torch_2_8 as onnxscript_apis diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index c3a0f26b227d..cf83aa406154 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -10,9 +10,9 @@ from typing import Any, Callable, TYPE_CHECKING import torch +from torch.onnx import _constants as onnx_constants from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir from torch.onnx._internal.exporter import ( - _constants, _core, _dynamic_shapes, _onnx_program, @@ -50,7 +50,7 @@ def export_compat( verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, - opset_version: int | None = _constants.TORCHLIB_OPSET, + opset_version: int | None = onnx_constants.ONNX_DEFAULT_OPSET, custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] @@ -70,7 +70,7 @@ def export_compat( legacy_export_kwargs: dict[str, Any] | None = None, ) -> _onnx_program.ONNXProgram: if opset_version is None: - opset_version = _constants.TORCHLIB_OPSET + opset_version = onnx_constants.ONNX_DEFAULT_OPSET if isinstance(model, torch.export.ExportedProgram): # We know the model is already exported program, so the args, kwargs, and dynamic_shapes diff --git a/torch/onnx/_internal/exporter/_fx_passes.py b/torch/onnx/_internal/exporter/_fx_passes.py index a14b25d7cda1..98359f2ebaff 100644 --- a/torch/onnx/_internal/exporter/_fx_passes.py +++ b/torch/onnx/_internal/exporter/_fx_passes.py @@ -37,8 +37,9 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.Graph torch.ops.aten._assert_scalar.default, torch.ops.aten._assert_tensor_metadata.default, } - for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target in aten_assertion_targets: - graph_module.graph.erase_node(node) - graph_module.recompile() + for gm in graph_module.modules(): + for node in gm.graph.nodes: # type: ignore[union-attr] + if node.op == "call_function" and node.target in aten_assertion_targets: + gm.graph.erase_node(node) # type: ignore[operator, union-attr] + gm.recompile() # type: ignore[operator] return graph_module diff --git a/torch/onnx/_internal/fx/decomposition_table.py b/torch/onnx/_internal/fx/decomposition_table.py deleted file mode 100644 index 71715e1ad234..000000000000 --- a/torch/onnx/_internal/fx/decomposition_table.py +++ /dev/null @@ -1,116 +0,0 @@ -# mypy: allow-untyped-defs -"""Dispatcher for AtenLib functions from onnx-script.""" - -from __future__ import annotations - -from typing import Callable - -import torch -import torch._ops -import torch.fx -from torch.onnx._internal.fx import registration - - -def _create_onnx_supports_op_overload_table( - registry, -) -> set[torch._ops.OperatorBase | Callable]: - """ - Creates a set of OperatorBase and Callable objects that represent ONNX-supported PyTorch operations. - - Args: - registry (OnnxRegistry): The ONNX registry for PyTorch. - - Returns: - A collection of OperatorBase and Callable objects representing ONNX-supported PyTorch operations. - """ - table: set[torch._ops.OperatorBase | Callable] = set() - - # Some ops in `torch.ops.aten` are not discoverable through `dir(torch.ops.aten)`, - # but retrievable via explicit lookup. - # https://github.com/pytorch/pytorch/issues/99681 - # This is a workaround to make sure we register ONNX symbolic functions for these. - onnx_supported_aten_lookup_table = [ - k.split("::")[1].split(".")[0] - for k in registry._all_registered_ops() - if k.startswith("aten::") - ] - - for op_namespace in (torch.ops.aten, torch.ops.prims): - attr_names = dir(op_namespace) - if op_namespace is torch.ops.aten: - attr_names += onnx_supported_aten_lookup_table - for attr_name in attr_names: - if not hasattr(op_namespace, attr_name): - # torchlib owns some attributes that are not aten ops. - continue - op_overload_packet = getattr(op_namespace, attr_name) - if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket): - continue - - for overload_name in op_overload_packet.overloads(): - op_overload = getattr(op_overload_packet, overload_name) - internal_op_name = registration.OpName.from_qualified_name( - qualified_name=op_overload.name() - ) - # NOTE: If the overload is supported in registry or it's default overload is supported in registry, - # we add it to the table. - if registry.is_registered_op( - namespace=internal_op_name.namespace, - op_name=internal_op_name.op_name, - overload=internal_op_name.overload, - ) or registry.is_registered_op( - namespace=internal_op_name.namespace, - op_name=internal_op_name.op_name, - overload=None, - ): - # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc - # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add". - # This is applied to all ops under torch.ops.aten. - table.add(op_overload) - return table - - -def create_onnx_friendly_decomposition_table( - registry, -) -> dict[torch._ops.OperatorBase, Callable]: - """ - This function creates a dictionary of op overloads and their decomposition functions - for ops that do not have ONNX symbolic functions. If an op already has an ONNX symbolic function, - its decomposition function is excluded from the table. The decomposition table is a subset of PyTorch's - built-in aten-to-aten decomposition. - - Args: - registry: The ONNX registry for PyTorch. - - Returns: - Dict[torch._ops.OperatorBase, Callable]: A dictionary that maps op overloads to their corresponding - decomposition functions. - """ - decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} - # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g., - # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add". - _ONNX_SUPPORT_OP_OVERLOADS = _create_onnx_supports_op_overload_table(registry) - - # NOTE: If we import torch._decomp, we will get RuntimeError: Only a single - # TORCH_LIBRARY can be used to register the namespace nvprims; please put all of your - # definitions in a single TORCH_LIBRARY block. - for op_overload, decomp_fn in torch._decomp.decomposition_table.items(): - # Skip decomposition into "prim::*" ops (defined in 'torch._refs'), because they - # are not generally supported by ONNX. - # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX - # symbolic function. - if ( - "torch._refs" in decomp_fn.__module__ - or op_overload in _ONNX_SUPPORT_OP_OVERLOADS - ): - continue - decomposition_table[op_overload] = decomp_fn - - # NOTE: There are ops in core ATen and under torch._refs, - # that are not decomposed to prim::ops. We need to pick them - # back - for op_overload, decomp_fn in torch._decomp.core_aten_decompositions().items(): - if op_overload in _ONNX_SUPPORT_OP_OVERLOADS: - continue - decomposition_table[op_overload] = decomp_fn - return decomposition_table diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py deleted file mode 100644 index b11903619c08..000000000000 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ /dev/null @@ -1,232 +0,0 @@ -# mypy: allow-untyped-defs -# NOTE: This file is referenced by name at -# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES. -# introduced by https://github.com/pytorch/pytorch/pull/98894. -# If this file is renamed, moved, etc please update the reference there! - -from __future__ import annotations - -import contextlib -import functools -import inspect -from typing import Any, Callable, TYPE_CHECKING - -import torch._dynamo -import torch.export as torch_export -import torch.fx -import torch.onnx -from torch.onnx._internal import _exporter_legacy, io_adapter -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - - -class _PyTreeExtensionContext: - """Context manager to register PyTree extension.""" - - _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]] - - def __init__(self) -> None: - self._extensions = {} - # Register PyTree extension for HuggingFace model output. - self._register_huggingface_model_output_extension() - - def __enter__(self): - for class_type, (flatten_func, unflatten_func) in self._extensions.items(): - pytree._private_register_pytree_node( - class_type, - flatten_func, - unflatten_func, - ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - for class_type in self._extensions: - pytree.SUPPORTED_NODES.pop(class_type) - - def register_pytree_node( - self, - class_type: type, - flatten_func: pytree.FlattenFunc, - unflatten_func: pytree.UnflattenFunc, - ): - """Register PyTree extension for a custom python type. - - Args: - class_type: The custom python type. - flatten_func: The flatten function. - unflatten_func: The unflatten function. - - Raises: - AssertionError: If the custom python type is already registered. - """ - if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions: - # PyTree node already registered. - # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after - # https://github.com/huggingface/transformers/pull/25358. - return - self._extensions[class_type] = (flatten_func, unflatten_func) - - def _register_huggingface_model_output_extension(self): - try: - from transformers import modeling_outputs # type: ignore[import] - except ImportError: - return - - def model_output_flatten( - output: modeling_outputs.ModelOutput, - ) -> tuple[list[Any], pytree.Context]: - return list(output.values()), (type(output), list(output.keys())) - - def model_output_unflatten( - values: list[Any], context: pytree.Context - ) -> modeling_outputs.ModelOutput: - output_type, keys = context - return output_type(**dict(zip(keys, values))) - - # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'. - named_model_output_classes = inspect.getmembers( - modeling_outputs, - lambda x: ( - inspect.isclass(x) - and issubclass(x, modeling_outputs.ModelOutput) - and x is not modeling_outputs.ModelOutput - ), - ) - - for _, class_type in named_model_output_classes: - self.register_pytree_node( - class_type, - model_output_flatten, - model_output_unflatten, # type: ignore[arg-type ] - ) - - -class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep): - """Flatten nested collection and custom python types and return a flat list of elements. - - Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary - types via pytree extension. By default this supports many common user defined python - types such as :class:`ModelOutput` from HuggingFace transformers. - - The pytree extension can be customized by passing in a ``_PyTreeExtensionContext`` - object. See :meth:`_PyTreeExtensionContext.register_pytree_node`. - """ - - def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None): - super().__init__() - self._pytree_extension_context = ( - pytree_extension_context or _PyTreeExtensionContext() - ) - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs, under the context of pytree extension.""" - with self._pytree_extension_context: - return super().apply(model_outputs, model=model) - - -def _wrap_model_with_output_adapter( - model: torch.nn.Module | Callable, - output_adapter: DynamoFlattenOutputStep, -) -> Callable: - """Wrap model with output adapter. - - This is a helper function to enable :func:`dynamo.export` on models that produce - custom user defined types outputs. It wraps the model with an output adapter to - convert the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`. - - The adapting logic is controlled by ``output_adapter``. - - Args: - model: PyTorch model or function. - output_adapter: Output adapter to apply to model output. - Returns: - Wrapped model. - """ - model_func = model.forward if isinstance(model, torch.nn.Module) else model - - # Preserve original function signature. - @functools.wraps(model_func) - def wrapped(*args, **kwargs): - return output_adapter.apply(model_func(*args, **kwargs), model=model) - - return wrapped - - -class DynamoExport(_exporter_legacy.FXGraphExtractor): - """Generates a FX GraphModule using torch.dynamo.export API - Args: - aten_graph: If True, exports a graph with ATen operators. - If False, exports a graph with Python operators. - """ - - def __init__( - self, - aten_graph: bool | None = None, - ): - super().__init__() - self.aten_graph = aten_graph or True - - def generate_fx( - self, - options: _exporter_legacy.ResolvedExportOptions, - model: torch.nn.Module | Callable, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - ) -> torch.fx.GraphModule: - # `dynamo.export` does not recognize custom user defined classes as output type. - # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types, - # i.e. :class:`torch.Tensor`. - dynamo_flatten_output_step = DynamoFlattenOutputStep() - wrapped_model = _wrap_model_with_output_adapter( - model, dynamo_flatten_output_step - ) - # Record the output adapter step. - self.output_adapter.append_step(dynamo_flatten_output_step) - - # Translate callable to FX graph. - # - fake_mode = ( - options.fake_context.fake_mode - if options.fake_context - else contextlib.nullcontext() - ) - fx_mode = "symbolic" if options.dynamic_shapes else "fake" - with fake_mode: # type: ignore[attr-defined] - graph_module, graph_guard = torch._dynamo.export( - wrapped_model, - tracing_mode=fx_mode, - )( - *model_args, - **model_kwargs, - ) - del graph_guard # Unused - torch._dynamo.reset() - - # Export FX graph to ONNX ModelProto. - self.input_adapter.append_step( - io_adapter.FlattenInputWithTreeSpecValidationInputStep() - ) - - updated_model_args = self.input_adapter.apply( - *model_args, model=model, **model_kwargs - ) - - return self.pre_export_passes(options, model, graph_module, updated_model_args) # type: ignore[return-value] - - def pre_export_passes( - self, - options: _exporter_legacy.ResolvedExportOptions, - original_model: torch.nn.Module | Callable, - fx_module: torch.fx.GraphModule, - fx_module_args: Sequence[Any], - ): - return _exporter_legacy.common_pre_export_passes( - options, original_model, fx_module, fx_module_args - ) diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py deleted file mode 100644 index 424f2d171b97..000000000000 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ /dev/null @@ -1,718 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import inspect -import operator -from typing import Callable, TYPE_CHECKING - -import onnxscript -from onnxscript.function_libs.torch_lib import ( - graph_building as onnxscript_graph_building, -) - -import torch -import torch.fx -from torch.onnx import _type_utils as jit_type_utils -from torch.onnx._internal.fx import ( - _pass, - onnxfunction_dispatcher, - type_utils as fx_type_utils, -) -from torch.utils import _pytree - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -def _fx_node_to_onnx_message_formatter( - fn: Callable, - self, - node: torch.fx.Node, - *args, - **kwargs, -) -> str: - return f"FX Node: {node.op}:{node.target}[name={node.name}]. " - - -def _fx_graph_to_onnx_message_formatter( - fn: Callable, - self, - fx_graph_module: torch.fx.GraphModule, - *args, - **kwargs, -) -> str: - return f"FX Graph: {fx_graph_module._get_name()}. " - - -def _retrieve_or_adapt_input_to_graph_set( - fx_node_arg: fx_type_utils.Argument, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, -): - """Map FX value to TorchScript value. - - When creating TorchScript graph from FX graph, we need a mapping from FX variable - to TorchScript variable. This function maps FX variable, fx_node_arg, to torch.jit.Value. - """ - from onnxscript import opset18 as op - - onnx_tensor = fx_node_arg - if isinstance(onnx_tensor, torch.fx.Node): - # 1. fx_node_arg is a torch.fx.Node, which means - # fx_node_arg stands for the output of that torch.fx.Node. - # 2. fx_node_arg (variable in torch.fx.Graph) is be mapped to - # torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name], - # in TorchScript graph. - return fx_name_to_onnxscript_value[onnx_tensor.name] - elif isinstance(onnx_tensor, (tuple, list)) and any( - isinstance(node, torch.fx.Node) - and fx_type_utils.is_torch_symbolic_type(node.meta.get("val")) - for node in onnx_tensor - ): - # This intends to handle dynamic axes. for example, if the input size of op.Expand - # is dynamic, each dimension would be variable (i.e., sym variable in Pytorch - # FX graph. Note that sym variable is mapped to tensor in ONNX Script world) - # calculated by other operators. - sequence_mixed_elements: list[ - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - | list[int] - ] = [] - # onnx_tensor contains a list of scalars which could be one of - # - tensor with empty shape, - # - tensor with tensor with shape (1,), - # - torch.SymInt, - # - int - # - ... - # They should all be promoted to tensor with shape (1,) - # in order to call ONNX's Concat. - for tensor in onnx_tensor: - # Prepare `tensor` as input of ONNX's Concat. - - if isinstance( - tensor, torch.fx.Node - ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")): - # In this case, tensor is a torch.SymInt from Dynamo's perspective. - # It might be mapped to tensor with shape () or (1,) in ONNX. - element_value = fx_name_to_onnxscript_value[tensor.name] - if isinstance( - element_value, onnxscript_graph_building.TorchScriptTensor - ): - # All elements sequence_mixed_elements will be send to onnx's Concat - # as inputs. Therefore, they are required to have the same rank. - # Since tensors with rank=0 (i.e., scalar) cannot be concated, all - # scalars are promoted to tensors with shape (1,). - with onnxscript.evaluator.default_as(tracer): - element_value = op.Reshape( - element_value, # type: ignore[arg-type, type-var] - [1], # type: ignore[arg-type, type-var] - ) - sequence_mixed_elements.append(element_value) - elif isinstance(tensor, int): - # NOTE: op.Concat doesn't support scalar, so we need to wrap it with - # dim, and onnx-script will promote it to tensor(int64) - sequence_mixed_elements.append([tensor]) - else: - raise RuntimeError( - f"Unsupported type in sequence_mixed_elements: {type(tensor)}" - ) - # Concat all the elements in the sequence. - # shapes are mapped to tensors in ONNX graph (TorchScriptGraph), - # so list of sym_ints is concatenated to a tensor before calling ONNX op. - - # For example: - # inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)] - # outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)]) - - # onnx-script auto wraps python number with op.Constants, - # so we don't need to specifically process them. - with onnxscript.evaluator.default_as(tracer): - output = op.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var] - output.dtype = torch.int64 # type: ignore[union-attr] - output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr] - return output - elif isinstance(onnx_tensor, (tuple, list)) and all( - isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor - ): - sequence_elements: list[ - onnxscript_graph_building.TorchScriptTensor - | None - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ] = [] - for tensor in onnx_tensor: - sequence_elements.append( - fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None # type: ignore[index, union-attr] - ) - return sequence_elements - if isinstance(onnx_tensor, torch.dtype): - onnx_tensor = int( # type: ignore[call-overload] - jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type() - ) - # NOTE: if device is specified in kwargs (not consumed), it's free to ignored. But - # if it's in args, we need to set it to string for dispatcher to match schema. - if isinstance(onnx_tensor, torch.device): - # torch.device is not supported by onnxscript (no op). We turn it into - # a string. - return str(onnx_tensor) - # all other cases, we do nothing. - return onnx_tensor - - -def filter_incompatible_and_dtype_convert_kwargs(kwargs): - """Filter out kwargs that are not supported by onnxscript.""" - filtered = {} - for key, value in kwargs.items(): - if key in { - "layout", - "device", - "requires_grad", - "pin_memory", - "memory_format", - "implicit", - }: - continue - if key == "dtype": - if value is None: - # We omit if dtype is not provided, because onnxscript handles the - # default case. - continue - else: - value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type()) # type: ignore[call-overload] - filtered[key] = value - return filtered - - -def _fill_tensor_shape_type( - onnxscript_values: onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - name: str, - expected_values: fx_type_utils.META_VALUE_TYPE - | list[fx_type_utils.META_VALUE_TYPE] - | tuple[fx_type_utils.META_VALUE_TYPE | None, ...], -): - """Fill the meta information of onnxscript_values with that from the fx FakeTensor.""" - - if isinstance(expected_values, (list, tuple)) and not isinstance( - onnxscript_values, (list, tuple) - ): - # ex: aten::split - in onnx_dtype: seq(tensor) - # onnxscript_values is a single tensor, but expected_values is a list of tensors. - return - - flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values) - flat_expected_values, _ = _pytree.tree_flatten(expected_values) - for i, (onnxscript_value, expected_value) in enumerate( - zip(flat_onnxscript_values, flat_expected_values) - ): - if expected_value is None: - # There is no shape/type from None. - # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py, - # None could be a valid value for return type, so we need to handle it. - # e.g. the function: meta__scaled_dot_product_flash() in cpu mode. - continue - elif fx_type_utils.is_torch_symbolic_type(expected_value): - # aten::sym_size output is a int, not a tensor, which stands - # for the size of one dim. We treat it as 1-D tensor. - onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype( - expected_value - ) - onnxscript_value.shape = torch.Size([1]) - elif isinstance(expected_value, (int, float, bool)): - onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype( - type(expected_value) - ) - onnxscript_value.shape = torch.Size([]) - elif isinstance(expected_value, complex): - # From complex scalar to real representation - onnxscript_value_to_torch_dtype = ( - fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value)) - ) - onnxscript_value.dtype = ( - fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype) - if onnxscript_value_to_torch_dtype is not None - else None - ) - onnxscript_value.shape = torch.Size([2]) - elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype): - # Like torch.view_as_real, we flatten complex tensors to real tensors with - # additional last dimension of 2 - onnxscript_value.shape = torch.Size((*expected_value.size(), 2)) - # complex64 -> float32, complex128 -> float64, etc. - onnxscript_value.dtype = fx_type_utils.from_complex_to_float( - expected_value.dtype - ) - # Dispatcher needs to know the value is complex - onnxscript_value.is_complex = True - else: - # We set node output sizes to be dynamic to continue the model conversion, - # and inputs are also set to be dynamic in add_input(). - onnxscript_value.shape = expected_value.size() - onnxscript_value.dtype = expected_value.dtype - - # naming - if i > 0: - onnxscript_value.name = f"{name}_{i}" - else: - onnxscript_value.name = name - - -def _fill_in_default_kwargs( - node: torch.fx.Node, -) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]: - """Find and Fill in the not provided kwargs with default values.""" - - # TODO: aten::sym_size has overload, but fx graph is using - # overloadpacket for some reasons. - # https://github.com/pytorch/pytorch/issues/97201 - # We manually assigned overload for aten::sym_size. - if hasattr(node.target, "_schema"): - node_schema = node.target._schema # type: ignore[union-attr] - else: - node_schema = torch.ops.aten.sym_size.int._schema # type: ignore[union-attr] - - # This function assumes the order of arguments in FX op is the - # same as the order of arguments in TorchScript op. - complete_args: list[fx_type_utils.Argument] = [] - complete_kwargs: dict[str, fx_type_utils.Argument] = {} - - if inspect.isbuiltin(node.target): - complete_args = list(node.args) - else: - for i, expected_arg in enumerate(node_schema.arguments): - if i < len(node.args): - complete_args.append(node.args[i]) - elif expected_arg.name in node.kwargs: - complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name] - else: - # Get default from schema. - complete_kwargs[expected_arg.name] = expected_arg.default_value - - return complete_args, complete_kwargs - - -def _wrap_fx_args_as_onnxscript_args( - complete_args: list[fx_type_utils.Argument], - complete_kwargs: dict[str, fx_type_utils.Argument], - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, -) -> tuple[ - Sequence[ - onnxscript_graph_building.TorchScriptTensor - | str - | int - | float - | bool - | list - | complex - | None - ], - dict[str, fx_type_utils.Argument], -]: - """Map all FX arguments of a node to arguments in TorchScript graph.""" - - onnxscript_args = tuple( - _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer) - for arg in complete_args - ) - onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs) - - return onnxscript_args, onnxscript_kwargs - - -class FxOnnxInterpreter: - """Stateless class to process FX graph Nodes and translate them into their ONNX counterparts. - - All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported. - Similarly to [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node - must be implemented on its own method in this class. - - Each operator's implementation returns either an `onnxscript.OnnxFunction` or - `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can - also raise RuntimeError: If there are no overloaded functions available for the given FX node. - """ - - def run_node( - self, - node, - fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - """Execute a single FX node to produce its ONNX counterpart. - - Args: - node: The FX node to be translated. - fx_graph_module: The FX graph module containing the node. - onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op. - onnxscript_graph: The ONNX graph to be populated. - onnxscript_tracer: The tracer to trace the ONNX graph. - fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value. - - Raises: - RuntimeError: When a node.op is not supported. - """ - if node.op == "placeholder": - self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value) - elif node.op == "get_attr": - self.get_attr( - node, - onnxscript_graph, - fx_name_to_onnxscript_value, - fx_graph_module, - ) - elif node.op == "call_function": - self.call_function( - node, - onnxscript_tracer, - fx_name_to_onnxscript_value, - onnxfunction_dispatcher, - fx_graph_module, - ) - elif node.op == "call_method": - self.call_method(node) - elif node.op == "call_module": - self.call_module( - node, - onnxscript_graph, - fx_name_to_onnxscript_value, - onnxscript_tracer, - fx_graph_module, - onnxfunction_dispatcher, - ) - elif node.op == "output": - self.output(node, onnxscript_graph, fx_name_to_onnxscript_value) - else: - raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}") - - def run( - self, - fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph - | None = None, - ) -> onnxscript_graph_building.TorchScriptGraph: - """Analyze all FX nodes and trigger their ONNX translation. - - Args: - fx_graph_module: FX graph module to be translated. - onnxfunction_dispatcher: ONNX function dispatcher. - parent_onnxscript_graph: The parent TorchScript graph. Must be provided if - `fx_graph_module` is a submodule. If not provided, - `fx_graph_module` is assumed to be the root module. - """ - if parent_onnxscript_graph is not None: - # If parent_onnxscript_graph is provided, we assume fx_graph_module is a - # submodule representing a forward call of an nn.Module. - # Compose package and version where the nn.Module is defined as domain name - # for the local function. - - onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get( - "onnx" - ) - if onnx_meta is None: - raise RuntimeError( - f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. " - f"Only submodules produced by `Modularize` pass is supported in ONNX export." - ) - - onnx_domain = onnx_meta.package_info.to_onnx_domain_string() - else: - # Leave as default domain name for the root module. - onnx_domain = None - - onnxscript_graph = onnxscript_graph_building.TorchScriptGraph( - parent_onnxscript_graph, domain_name=onnx_domain - ) - onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator( - onnxscript_graph - ) - # In the following loop, a TorchScript graph is created to - # represent the input FX graph with ONNX symbols (e.g., onnx::add). - # To connect the values to nodes in the TorchScript graph, we maintain - # fx_name_to_onnxscript_value. Basically, we want to translate - # fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node) - # to - # fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name] - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ] = {} - - # TODO: Fix FakeTensorMode limitation asap - # We want to pass list of ints and floats to TorchScript graph correctly - # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, graph may - # receive FakeTensor and results runtime error. In addition, TorchScript-based - # ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf is not compatible - # with FakeTensorMode. - with torch.utils._mode_utils.no_dispatch(): - for node in fx_graph_module.graph.nodes: - self.run_node( - node, - fx_graph_module, - onnxfunction_dispatcher, - onnxscript_graph, - onnxscript_tracer, - fx_name_to_onnxscript_value, - ) - - return onnxscript_graph - - def placeholder( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - # Input of graph. - # The node.meta["val"] is generated by FakeTensorProp. - # NOTE: add_input() intends to create nodes with shape/type - fake_tensor = node.meta.get("val", None) - # NOTE: During the tracing, when inputs are constants, they are represented - # by nodes with node.meta['val'] being None (nn.Module to dynamo_export) - # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export). - # Nonethless, the nodes are not consumed by others, so we don't need to - # create a TorchScriptTensor for them. - if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)): - output = onnxscript_graph.add_input( - input_name=None, - ) - elif isinstance(fake_tensor, torch.Tensor): - # NOTE: ONNX doesn't support tensor of complex64/complex128, so we - # convert them to float32/float64 with real representation. - if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype): - fake_tensor = torch.view_as_real(fake_tensor.resolve_conj()) - output = onnxscript_graph.add_input( - input_name=node.name, - shape=fake_tensor.shape, - dtype=fake_tensor.dtype, - ) - - elif fx_type_utils.is_torch_symbolic_type(fake_tensor): - output = onnxscript_graph.add_input( - input_name=node.name, - shape=torch.Size([]), - dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor), - ) - else: - raise RuntimeError( - f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}" - ) - assert output is not None, ( - f"Node creates None with target={node.target} and name={node.name}" - ) - - assert isinstance(output, onnxscript_graph_building.TorchScriptTensor) - assert isinstance(output, onnxscript.tensor.Tensor) - - fx_name_to_onnxscript_value[node.name] = output - - def call_function( - self, - node: torch.fx.Node, - onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - fx_graph_module: torch.fx.GraphModule, - ): - # aten ops and other stateless functions. - if node.target == operator.getitem and isinstance( - fx_name_to_onnxscript_value[node.args[0].name], # type: ignore[union-attr,index] - tuple, - ): - onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] - index = node.args[1] - value = onnx_tensor_tuple[index] # type: ignore[index] - assert value is not None, ( - f"Node creates None with target={node.target} and name={node.name}" - ) - assert isinstance( - value, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), type(value) - - fx_name_to_onnxscript_value[node.name] = value - return - - # Map FX inputs to ONNX inputs and fill optional inputs with default values. - # torch_args and torch_kwargs are for op-level validation - fx_args, fx_kwargs = _fill_in_default_kwargs(node) - - onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args( - fx_args, - fx_kwargs, - fx_name_to_onnxscript_value, - onnxscript_tracer, - ) - # Dispatch to ONNX op through OpShema. The input argument dtypes are compared to - # function signature in OpSchema, and find the best matched overload. - symbolic_fn = onnxfunction_dispatcher.dispatch( - node=node, - onnx_args=onnx_args, # type: ignore[arg-type] - onnx_kwargs=onnx_kwargs, - ) - with onnxscript.evaluator.default_as(onnxscript_tracer): - output: ( - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ) = symbolic_fn(*onnx_args, **onnx_kwargs) - assert output is not None, ( - f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" - ) - # Assign type and shape from fx graph. - _fill_tensor_shape_type(output, node.name, node.meta["val"]) - # One fx node could produce multiple outputs (e.g., tuple of tensors); in - # that case, v is a tuple of TorchScriptTensors. - assert isinstance( - output, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), type(output) - fx_name_to_onnxscript_value[node.name] = output - - def output( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - ): - if isinstance(node.args[0], torch.fx.Node): - onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] - onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) - else: - # ONNX can't represent collection types (e.g., dictionary, tuple of tuple of - # tensor, etc), we flatten the collection and register each element as output. - flat_args, _ = _pytree.tree_flatten(node.args[0]) - for arg in flat_args: - assert isinstance(arg, torch.fx.Node), ( - f"arg must be a torch.fx.Node, not {type(arg)}" - ) - onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name] - onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) - - def call_method(self, node: torch.fx.Node): - # TODO(wechi): Support call_method. - raise RuntimeError("call_method is not supported yet.") - - def call_module( - self, - node: torch.fx.Node, - parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - tracer: onnxscript_graph_building.TorchScriptTracingEvaluator, - root_fx_graph_module: torch.fx.GraphModule, - onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher, - ) -> None: - """Export a fx.GraphModule submodule to ONNXScript graph. - - The export process specifically targets `call_module` nodes that are created by - the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule - by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX - function nodes. The related `sub_module` is then exported as an ONNX model local function, - which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current - `onnxscript_graph` as its parent. - - Args: - node: The call_module node in the FX graph that represents the submodule call. - parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and - function node belong. - fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value. - tracer: The tracer used to trace the ONNXScript graph. - root_fx_graph_module: The root FX module. - onnxfunction_dispatcher: The dispatcher. - """ - assert isinstance(node.target, str), ( - f"node.target must be a str, not {type(node.target)} for node {node}." - ) - - sub_module = root_fx_graph_module.get_submodule(node.target) - - assert isinstance(sub_module, torch.fx.GraphModule), ( - f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." - ) - - sub_onnxscript_graph = self.run( - sub_module, onnxfunction_dispatcher, parent_onnxscript_graph - ) - - onnx_args, _ = _wrap_fx_args_as_onnxscript_args( - list(node.args), {}, fx_name_to_onnxscript_value, tracer - ) - - # TODO: We may want to consider other naming styles. The goal is to be stable and - # unique such that it can be easily identified in case of kernel substitution. - # Example for current style is combination of qualified module class name and - # module attribute name: `torch_nn_modules_conv_Conv2d_conv1`. - # Other naming styles such as qualified module class name made unique can also - # be considered. - unique_module_name = f"{sub_module._get_name()}_{node.target}" - - outputs: ( - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...] - ) = parent_onnxscript_graph.add_module_call( # type: ignore[assignment] - unique_module_name, sub_onnxscript_graph, onnx_args - ) - - assert isinstance( - outputs, (onnxscript_graph_building.TorchScriptTensor, tuple) - ), f"Unexpected outputs type {type(outputs)} for node {node}." - - _fill_tensor_shape_type(outputs, node.name, node.meta["val"]) - fx_name_to_onnxscript_value[node.name] = outputs - - # Skip op_level_validation for call_module. Subgraph nodes are validated individually. - - def get_attr( - self, - node: torch.fx.Node, - onnxscript_graph: onnxscript_graph_building.TorchScriptGraph, - fx_name_to_onnxscript_value: dict[ - str, - onnxscript_graph_building.TorchScriptTensor - | tuple[onnxscript_graph_building.TorchScriptTensor, ...], - ], - fx_graph_module: torch.fx.GraphModule, - ): - assert isinstance(node.target, str), f"node.target {node.target} is not a str." - attr_tensor = getattr(fx_graph_module, node.target) - assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor." - - # Parameter/buffer name cannot contain "." - # Revert from "/" to restore namespace formatting. - input_ = onnxscript_graph.add_initializer( - name=node.target.replace("/", "."), - value=attr_tensor, - ) - - assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor) - assert isinstance(input_, onnxscript.tensor.Tensor) - fx_name_to_onnxscript_value[node.name] = input_ diff --git a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py b/torch/onnx/_internal/fx/onnxfunction_dispatcher.py deleted file mode 100644 index 516eb3636888..000000000000 --- a/torch/onnx/_internal/fx/onnxfunction_dispatcher.py +++ /dev/null @@ -1,731 +0,0 @@ -# mypy: allow-untyped-defs -"""Dispatcher for AtenLib functions from onnx-script. - -This is a deprecated module to be removed. -""" - -from __future__ import annotations - -import logging -import operator -import types -from typing import Any, TYPE_CHECKING - -import torch -import torch._ops -import torch.fx -from torch.onnx._internal.fx import registration, type_utils as fx_type_utils - - -if TYPE_CHECKING: - from collections.abc import Sequence - - import onnxscript # type: ignore[import] - from onnxscript.function_libs.torch_lib import ( # type: ignore[import] - graph_building as onnxscript_graph_building, - ) - - from torch.onnx._internal._exporter_legacy import OnnxRegistry - - -logger = logging.getLogger(__name__) - - -class OnnxFunctionDispatcher: - """A dispatcher that finds the best ONNX Function for ATen/Custom operators. - - It uses the `torch.ops` name to find the function. If not found, it falls back to default. - Otherwise, the best match is found among all function overloads. An exact match has - higher precedence over the closest ones. - - Below is a breakdown on how the dispatch mechanism works: - - 1. Use the torch.ops name to find the function: - a. Check if the ATen overload exists in the registry. - b. If not, check if the default overload exists in the registry. - - 2. Find the nearest match among all overloaded functions: - a. If the types match perfectly, select the function. - b. Otherwise, find the nearest one with the highest matching score. Because of - the potential wrongly annotated dtypes and attributes matching, we use - nearest match to find the best function once the aten name is targeted. - - 3. Tie-breaker: If there are multiple nearest matches, we will select the one with - the highest matching score. - - NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged. - """ - - def __init__( - self, - onnx_registry: OnnxRegistry, - ): - """Initialize the ONNX Function dispatcher. - - Args: - onnx_registry: The ONNX registry. - """ - self.onnx_registry = onnx_registry - - def dispatch( - self, - node: torch.fx.Node, - onnx_args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - onnx_kwargs: dict[str, fx_type_utils.Argument], - ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction: - """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments. - Args: - node: The TorchFX node to dispatch the function for. - onnx_args: The arguments of the ONNX function. - onnx_kwargs: The keyword arguments of the ONNX function. - - Returns: - Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. - Raises: - RuntimeError: If there are no overloaded functions available for the given FX node. - """ - # If there are no overloaded functions available for the given FX node, raise an - # unsupported error - default_and_custom_functions = self.get_function_overloads(node) - - # If there are overloaded functions available, we will find one that perfect or - # nearest matches the given arguments and keyword arguments - return self._find_the_perfect_or_nearest_match_onnxfunction( - node, - default_and_custom_functions, - onnx_args, - onnx_kwargs, - ) - - def _filter_or_keep_complex( - self, - node, - default_and_custom_functions: list[registration.ONNXFunction], - ) -> list[registration.ONNXFunction]: - """Filter the complex functions if the input has complex dtype.""" - - args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args] - if any(args_with_complex_dtype): - default_and_custom_functions = [ - func for func in default_and_custom_functions if func.is_complex - ] - # If we can't find the complex function group, raise error. - if not default_and_custom_functions: - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Cannot find any COMPLEX symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - else: - default_and_custom_functions = [ - func for func in default_and_custom_functions if not func.is_complex - ] - # If we can't find the complex function group, raise error. - if not default_and_custom_functions: - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Can ONLY find COMPLEX symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - return default_and_custom_functions - - def _find_the_perfect_or_nearest_match_onnxfunction( - self, - node: torch.fx.Node, - default_and_custom_functions: list[registration.ONNXFunction], - onnx_args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - onnx_kwargs: dict[str, fx_type_utils.Argument], - ): - """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments. - - Args: - default_and_custom_functions: The list includes overloaded functions, with - custom ones appearing after the default ones. - onnx_args: Arguments organized in PyTorch inputs way. - onnx_kwargs: Keyword arguments organized in PyTorch inputs way. - - Returns: - Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. - Raises: - RuntimeError: If there are no overloaded functions available for the given FX node. - """ - overload_match_ranking: dict[registration.ONNXFunction, int | None] = {} - - # Iterate the overloaded functions in reverse order to prioritize the custom ones - # over the default ones, and find the perfect match. - for symbolic_function in reversed(default_and_custom_functions): - function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function) - - # NOTE: 1. If the perfect match is found, return the function - if function_opschema.perfect_match_inputs(onnx_args, onnx_kwargs): - return symbolic_function.onnx_function - # Record the match score for the nearest match if it's not the perfect match - overload_match_ranking[symbolic_function] = function_opschema.match_score - - # NOTE: 2. If there is no perfect match, find the nearest match among the nearest matche candidates - # If there is no nearest match, raise an error - overload_match_ranking = { - k: v for k, v in overload_match_ranking.items() if v is not None - } - if not overload_match_ranking: - # If there are no overloaded functions available for the given FX node, raise an - # unsupported error - op_full_name = self._get_aten_name(node).qualified_name() - raise RuntimeError( - f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}," - f"which should be registered under {node.target}.", - ) - - # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one - # that is custom first. If there are multiple custom ones, we will choose the one - # that is added lastly in the list. - symbolic_function_list: list[registration.ONNXFunction] = sorted( - overload_match_ranking, - key=lambda k: ( - overload_match_ranking[k], - k.is_custom, - default_and_custom_functions.index(k), - ), - reverse=True, - ) - return symbolic_function_list[0].onnx_function - - def _get_aten_name(self, node: torch.fx.Node) -> registration.OpName: - """Get the OpName from the target. - - Args: - node: The TorchFX node to get the aten name for. - - Returns: - The internal op name within dataclass: registration.OpName. - """ - if node.target == operator.getitem: - return registration.OpName.from_name_parts( - namespace="aten", op_name="getitem" - ) - if isinstance(node.target, torch._ops.OpOverloadPacket): - # aten::sym_size is the only OverloadPacket that we support. - # schema: aten::sym_size(Tensor self, int dim) -> Tensor - if node.target != torch.ops.aten.sym_size: - raise RuntimeError( - f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!", - ) - # TODO(titaiwang): aten::sym_size has overload, but fx graph is using - # overloadpacket for some reasons. - # https://github.com/pytorch/pytorch/issues/97201 - aten_op_default = node.target.default - return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-return] - - if isinstance(node.target, types.BuiltinFunctionType): - # Make sure it's symint/symfloat consuming builtin ops. - for node_arg in node.args: - if (not isinstance(node_arg, (torch.fx.Node, int, float))) or ( - isinstance(node_arg, torch.fx.Node) - and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"]) - ): - raise RuntimeError( - f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target}," - " only int/float/SymInt/SymFloat is supported with built-in ops!", - ) - return registration.OpName.from_builtin_function(node.target) - - if isinstance(node.target, torch._ops.OpOverload): - return registration.OpName.from_op_overload(op_overload=node.target) - - # Unexpected target, raise error. - raise RuntimeError(f"Unknown call_function target: {node.target}") - - def get_function_overloads( - self, - node: torch.fx.Node, - ) -> list[registration.ONNXFunction]: - """Get the function overloads from the registry. - - Args: - node: The node to get the function overloads for. - - Returns: - The list contains ONNXFunctions, starting with the default ones and - followed by any custom ones. - """ - - internal_opname: registration.OpName = self._get_aten_name(node=node) - - # If the ATen/Custom operators are not registered, the group will be None. - # And non-registered ATen/Custom operators will trigger error in the next step. - function_group: list[registration.ONNXFunction] | None = None - - function_group = self.onnx_registry.get_op_functions( - namespace=internal_opname.namespace, - op_name=internal_opname.op_name, - overload=internal_opname.overload, - ) - - # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload. - if function_group is None: - function_group = self.onnx_registry.get_op_functions( - namespace=internal_opname.namespace, - op_name=internal_opname.op_name, - overload=None, - ) - if function_group is not None: - op_full_name = internal_opname.qualified_name() - - if function_group is not None: - # NOTE: If the input has complex dtype, we will only dispatch to the complex functions. - function_group = self._filter_or_keep_complex(node, function_group) - return function_group # type: ignore[return-value] - - op_full_name = internal_opname.qualified_name() - raise RuntimeError( - f"Cannot find symbolic function for {op_full_name}, " - f"which should be registered under {node.target}.", - ) - - -class _OnnxSchemaChecker: - """ - The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema. - - It provides methods to check for input compatibility based on the OpSchema. It also - provides a matching score to indicate how well the OpSchema matches the input and - kwargs types. A function will be evaluated as perfect match, nearest match eligible, - or no match. - - Here are some common examples in categories: - - 1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as - the OpSchema. The types of inputs and attributes are exactly the same as the - OpSchema. - - ```python - inputs = (Tensor[2, 3], Tensor[2, 3]) - attributes = {"alpha": 1.0} - - - @torch_op("aten::op") - def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ... - ``` - Result: Perfect match. - - 2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However, - the input can't be ignored. None must be provided. - - ```python - inputs = (Tensor([2, 3]), None) - attributes = {} - - aten_op(X: TTensor, Y: Optional[INT64]): - ... - ``` - Result: Perfect match. - Real example: `aten::convolution`. - - 3. [NOTE: Different attributes]: If an attribute is provided with value, it's - a must to match the attribute in function signature. - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a":1, "b":2} - - aten_op(X: TTensor, a: int): - ... - ``` - Result: No match. - Real example: `aten::div` vs `aten::div.Tensor_mode`. - - 4. [NOTE: Default attributes]: Default attribute will fill in the value into - inputs/attributes. - ```python - inputs = (Tensor([2, 3]),) - attributes = {} - - aten_op(X: TTensor, a: int = 3): - ... - ``` - Result: Perfect match. - Real example: `aten::clone` - - 5. [NOTE: Ignore attribute with None value]: The attributes with None value - will be ignored in matching. - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a": None} - - aten_op(X: TTensor): - ... - ``` - Result: Perfect match. - - ```python - inputs = (Tensor([2, 3]),) - attributes = {"a": None} - - aten_op(X: TTensor, a: int = 3): - ... - ``` - Result: Nearest match eligible. - - Real example: `aten::div` vs `aten::div.Tensor_mode`. - - Attributes: - onnxfunction: The OnnxFunction. - param_schema: The parameter schema defined in the OnnxFunction. - op_schema: The ONNX OpSchema. - type_constraints: The type constraints defined in the OpSchema. - attributes: The attributes defined in the OpSchema. - _matching_score: The matching score of the OnnxSchemaChecker . - - """ - - def __init__( - self, - onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction, - ): - """Initialize the OnnxSchemaChecker . - - Args: - onnxfunction: The OnnxFunction. - """ - self.onnxfunction = onnxfunction - self.param_schema = self.onnxfunction.param_schemas() - op_schema = self.onnxfunction.op_schema - # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`. - # However their base class would. Hence return type is annotated as Optional[OpSchema]. - assert op_schema is not None - self.op_schema = op_schema - self.type_constraints = { - # "T": {"tensor(int64)"} - constraint.type_param_str: set(constraint.allowed_type_strs) - for constraint in self.op_schema.type_constraints - } - self.attributes = self.op_schema.attributes - self._matching_score: int | None = None - - @property - def match_score(self) -> int | None: - """The matching score of the OnnxSchemaChecker . - - If this remains None, it means the matching score has not been calculated, - and it's not a nearest match candidate. - - Returns: - The matching score of the OnnxSchemaChecker . - """ - return self._matching_score - - def perfect_match_inputs( - self, - args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - kwargs: dict[str, fx_type_utils.Argument], - ) -> bool: - """Check if the inputs perfectly match the OpSchema requirements. - - The definition of perfect match is that the input types are all in the type - constraints and the number of inputs matches the number of inputs in the - OpSchema. - - Checking steps: - 1. The function signature matches the inputs number, and attribute names. - 2. The input/attribute types are all in the type constraints. - - A function should at least pass the first step to be eligible for the - nearest matching. - - Args: - args: The input arguments organized in PyTorch inputs way. - kwargs: The input keyword arguments organized in PyTorch inputs way. - - Returns: - True if the inputs match the requirements, False otherwise. - """ - - # NOTE: OnnxFunction does not have the same function signature as the original - # PyTorch operator. We need to separate the input/attributes from the arguments. - ( - function_inputs, - function_attributes, - ) = self._separate_input_attributes_from_arguments( - self.param_schema, - args, - kwargs, - fill_defaults=True, # fill defaults for optional arguments to match - ) - # NOTE: 1. Check if the input number and attribute names match the - # OpSchema. If it's not, we know the function is not eligible to be a perfect - # match, nor a nearest match. - # We use is_perfect_match to postpone the return value to the end - # of the function, as we want to log all the mismatch info. - is_perfect_match = True - if len(function_inputs) != len(self.op_schema.inputs): - logger.info( - "Actual %d vs expected %d", - len(function_inputs), - len(self.op_schema.inputs), - ) - logger.info("The function is not a nearest match candidate.") - is_perfect_match = False - - if set(function_attributes) != set(self.attributes): - logger.info("The function is not a nearest match candidate.") - is_perfect_match = False - - # If it's already not a perfect match, we can return False directly. Further - # checking is only for the functions that are eligible for nearest match. - if not is_perfect_match: - return False - - # NOTE: 2. The dtypes of inputs and attributes should be in the - # type constraints of the OpSchema. If they are not, we know the function is not - # eligible to be a perfect match, but can be a nearest match candidate. - for schema_input, torch_input in zip(self.op_schema.inputs, function_inputs): - torch_input_compatible_types = _find_onnx_data_type(torch_input) - allowed_types = self.type_constraints[schema_input.type_str] - if not allowed_types.intersection(torch_input_compatible_types) and not any( - fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str) - for onnx_type_str in allowed_types - ): - # If torch_input_compatible_types isn't in allowed_types - # of this input defined in the OpSchema, we know the function - # and the input are not compatible - logger.info( - "Actual %s vs\nExpected %s", - torch_input_compatible_types, - allowed_types, - ) - is_perfect_match = False - - for attribute_name, attribute in function_attributes.items(): - if not self._match_onnx_attribute_type(attribute_name, attribute): - # If the attribute type of the OpSchema and the attribute type don't match, - # we know the function and the input are not compatible - logger.info( - "Actual %s vs\nExpected %s", - type(attribute), - self.attributes[attribute_name].type, - ) - is_perfect_match = False - - # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype. - self._record_matching_score(function_inputs, function_attributes) - logger.info("match score: %d", self.match_score) - return is_perfect_match - - def _match_onnx_attribute_type( - self, - attribute_name: str, - attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor, - is_sequence: bool = False, - ) -> bool: - if isinstance(attribute, (int, float, bool, str)): - attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( - type(attribute), is_sequence=is_sequence - ) - if attribute_onnx_type != self.attributes[attribute_name].type: - return False - # If the attribute is an empty list, we don't know the type of the list - # so it's a mismatch - elif isinstance(attribute, (list, tuple)) and attribute: - return self._match_onnx_attribute_type( - attribute_name, attribute[0], is_sequence=True - ) - else: - # NOTE: Unrecognized attribute type - return False - return True - - def _record_matching_score( - self, - inputs: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - attributes: dict[str, fx_type_utils.Argument], - ): - """Calculate the inputs matching score of the OpSchema requirements to find the nearest match. - - Only the functions which have the same number of inputs and attributes as the - OpSchema are eligible to be a nearest match candidate. Thus, we don't need to - check the length of inputs and attributes here, and only check the types of - inputs and attributes. - - How the matchsing score is calculated: - score += 1 if one input/attribute type is in the type constraints. - - Limitations: - None/NoeType/[] could result in zero matches, and the same score of overloads. - - Args: - inputs: The input arguments. - attributes: The input keyword arguments. - - Returns: - True if the inputs match the requirements, False otherwise. - """ - self._matching_score = 0 - # If they have different length of arguments, the score would be lower to those - # functions which have the same length of arguments. - for schema_input, torch_input in zip(self.op_schema.inputs, inputs): - torch_input_compatible_types = _find_onnx_data_type(torch_input) - allowed_types = self.type_constraints[schema_input.type_str] - if allowed_types.intersection(torch_input_compatible_types): - # If torch_input_compatible_types is in allowed_types - # of this input defined in the OpSchema, we know the function - # and the input are compatible - self._matching_score += 1 - # NOTE: The penalty is applied to those functions which have different attributes. - for attribute_name, attribute_proto in self.attributes.items(): - attribute = attributes[attribute_name] - attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type( - type(attribute) - ) - if attribute_onnx_type != attribute_proto.type: - # If the attribute type of the OpSchema and the attribute type don't match, - # we know the function and the input are not compatible - self._matching_score -= 1 - - # NOTE: Referenced from onnxscript internal function. - # Importing this function makes the code less robust, as it is not a public API. - - def _separate_input_attributes_from_arguments( - self, - param_schemas: Sequence[onnxscript.values.ParamSchema], - args: Sequence[ - fx_type_utils.TensorLike | str | int | float | bool | list | complex | None - ], - kwargs: dict[str, fx_type_utils.Argument], - fill_defaults: bool = True, - ) -> tuple[list[Any], dict[str, Any]]: - """Separate Python args and kwargs into ONNX inputs and attributes. - - Extra_kwargs are ignored if their values are None. For example, if the - OpSchema has an attribute "rounding_mode" and the caller provides - "rounding_mode=None", the attribute "rounding_mode" will not be included - in the returned attributes when the OnnxFunction signature doesn't have - "rounding_mode" as an attribute. - - Args: - param_schemas: The parameter schemas of an Op or a OnnxFunction. - args: The Python positional arguments supplied by the caller. - kwargs: The Python keyword arguments supplied by the caller. - fill_defaults: Whether to fill the default values for attributes. - - Returns: - A tuple of two elements: - - A list of ONNX inputs. - - An dictionary of ONNX attribute names and values. - - Raises: - TypeError: When allow_extra_kwargs is False and there are unknown kwargs. - TypeError: When a required input is not provided. - """ - # args, kwargs and param_schemas should be all in order - # user may not specify all inputs or attributes - - import onnx - - onnx_inputs: list[Any] = [] - onnx_attributes: dict[str, Any] = {} - # NOTE: We need to copy kwargs because we will mutate it - copy_kwargs = kwargs.copy() - for i, param in enumerate(param_schemas): - if param.is_variadic_input: - # Exhaust all remaining args - onnx_inputs.extend(args[i:]) - args = [] - continue - if i < len(args): - if param.is_input: - onnx_inputs.append(args[i]) - else: - onnx_attributes[param.name] = args[i] - elif param.name in copy_kwargs: - if param.is_input: - # Move the input from kwargs to inputs - onnx_inputs.append(copy_kwargs[param.name]) - copy_kwargs.pop(param.name) - else: - onnx_attributes[param.name] = copy_kwargs[param.name] - elif ( - param.is_attribute - and self.attributes[param.name].default_value.type - != onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined] - ): - # User did not provide the attribute - if fill_defaults: - onnx_attributes[param.name] = param.default - # optional input - elif param.is_input: - if fill_defaults: - onnx_inputs.append(None) - - # NOTE: Pick up extra kwargs if it's not None. None is not expected - # as an attribute value in torchlib. - for k, v in copy_kwargs.items(): - if k not in onnx_attributes and v is not None: - onnx_attributes[k] = v - return onnx_inputs, onnx_attributes - - -def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool: - """Check if the node has complex dtype recursively.""" - if ( - isinstance(arg, torch.fx.Node) - and "val" in arg.meta - and isinstance(arg.meta["val"], torch.Tensor) - and torch.is_complex(arg.meta["val"]) - ): - return True - elif isinstance(arg, list): - for item in arg: - return _is_arg_with_complex_dtype(item) - return False - - -def _find_onnx_data_type( - torch_input: fx_type_utils.TensorLike - | str - | int - | float - | bool - | list - | tuple - | complex - | None, -) -> set[str]: - """Convert inputs data type from torch acceptable dtype to the compatible onnx dtype string.""" - if ( - isinstance(torch_input, fx_type_utils.TensorLike) - and torch_input.dtype is not None - ): - return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype) - if isinstance(torch_input, (int, float, bool, str, complex)): - return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input)) - if isinstance(torch_input, (list, tuple)) and torch_input: # [Tensor, Tensor] - the_first_non_none_item = next( - (item for item in torch_input if item is not None), None - ) - set_dtype = _find_onnx_data_type(the_first_non_none_item) - if any(isinstance(input, fx_type_utils.TensorLike) for input in torch_input): - # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type)) - return {f"seq({dtype})" for dtype in set_dtype} - else: - # constant list of non-tensor type - return set_dtype - if ( - torch_input is None - or ( - isinstance(torch_input, fx_type_utils.TensorLike) - and torch_input.dtype is None - ) - or (isinstance(torch_input, (list, tuple)) and not torch_input) - ): - # NOTE: None, No dtype, and empty list are edge cases, we allow it to be any type to relax the type check - # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case. - return set() - - raise RuntimeError(f"Unknown input type from input: {torch_input}") diff --git a/torch/onnx/_internal/fx/passes/__init__.py b/torch/onnx/_internal/fx/passes/__init__.py index aa04e6beb5f1..eff83563a5a0 100644 --- a/torch/onnx/_internal/fx/passes/__init__.py +++ b/torch/onnx/_internal/fx/passes/__init__.py @@ -1,18 +1,6 @@ -from .decomp import Decompose -from .functionalization import Functionalize, RemoveInputMutation -from .modularization import Modularize -from .readability import RestoreParameterAndBufferNames from .type_promotion import InsertTypePromotion -from .virtualization import MovePlaceholderToFront, ReplaceGetAttrWithPlaceholder __all__ = [ - "Decompose", "InsertTypePromotion", - "Functionalize", - "Modularize", - "MovePlaceholderToFront", - "RemoveInputMutation", - "RestoreParameterAndBufferNames", - "ReplaceGetAttrWithPlaceholder", ] diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py deleted file mode 100644 index 1573264d6fc7..000000000000 --- a/torch/onnx/_internal/fx/passes/decomp.py +++ /dev/null @@ -1,87 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import contextlib -from typing import Callable, TYPE_CHECKING - -import torch -import torch._ops -from torch._dispatch import python as python_dispatch -from torch._subclasses import fake_tensor -from torch.fx.experimental import proxy_tensor -from torch.onnx._internal.fx import _pass -from torch.onnx._internal.fx.passes import _utils - - -if TYPE_CHECKING: - from collections.abc import Mapping - - import torch.fx - - -class Decompose(_pass.Transform): - def __init__( - self, - module: torch.fx.GraphModule, - decomposition_table: Mapping[torch._ops.OpOverload, Callable], - enable_dynamic_axes: bool, - allow_fake_constant: bool | None = False, - ): - super().__init__(module) - self.decomposition_table = decomposition_table - self.enable_dynamic_axes = enable_dynamic_axes - self.allow_fake_constant = allow_fake_constant - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - assert not kwargs, "kwargs is not supported in Decompose." - - # To preserve stack trace info after `make_fx`. - module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) - - # fake mode use static size to trace the size of tensors. while symbolic - # mode generates aten::sym_size to dynamically trace the size of tensors. - - # e.g. fake mode: - # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [3, 5, 20]) - - # e.g. symbolic mode: - # sym_size = torch.ops.aten.sym_size(x, 0) - # sym_size_1 = torch.ops.aten.sym_size(x, 1) - # sym_size_2 = torch.ops.aten.sym_size(x, 2) - # sym_size_3 = torch.ops.aten.sym_size(x, 3) - # mul = sym_size_2 * sym_size_3; sym_size_2 = sym_size_3 = None - # view: f32[3, 5, 20] = torch.ops.aten.view.default(x, [sym_size, sym_size_1, mul]) - - # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. - # TODO: May need revisit for user fake mode export + dynamic shape scenario. - fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode - maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) - if fake_mode is not None: - # Using existing fake mode as context, signal `make_fx` that it does not need - # to create a new fake mode by passing tracing_mode as "real". - tracing_mode = "real" - else: - # Existing fake mode not found, signal `make_fx` to create one. - fake_mode = contextlib.nullcontext() # type: ignore[assignment] - tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" - - # Apply decomposition table to the input graph. - assert fake_mode is not None # for mypy - with ( - fake_tensor.unset_fake_temporarily(), - python_dispatch.enable_python_dispatcher(), - fake_mode, - ): - decomposed_module = proxy_tensor.make_fx( - module, - decomposition_table=self.decomposition_table, - tracing_mode=tracing_mode, - _allow_non_fake_inputs=True, - _allow_fake_constant=bool(self.allow_fake_constant), - )(*maybe_fake_args) - - # Rename placeholder targets to match the original module's signature since - # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). - _utils.replace_placeholder_name_and_target(decomposed_module, self.module) - - return decomposed_module diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py deleted file mode 100644 index fd8d3c7d48ac..000000000000 --- a/torch/onnx/_internal/fx/passes/functionalization.py +++ /dev/null @@ -1,152 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import contextlib -from typing import Callable - -import torch -import torch._ops -import torch.func -import torch.fx -from torch._subclasses import fake_tensor -from torch.fx.experimental import proxy_tensor -from torch.onnx._internal.fx import _pass -from torch.onnx._internal.fx.passes import _utils -from torch.utils import _pytree as pytree - - -class Functionalize(_pass.Transform): - """Functionalize a GraphModule. - - This pass utilizes ``functionalization`` utility of ``torch._functorch`` to convert - a GraphModule into a functional form. The two main functionalities are (copied from - its documentations): - - * ``functionalization`` removes (intermediate) mutations and aliasing from a - function, while preserving the function's semantics. - - * ``functionalization`` also removes mutations (and views) that were performed - on function inputs. However to preserve semantics, functionalize will "fix up" the - mutations after the transform has finished running, by detecting if any tensor inputs - "should have" been mutated, and copying the new data back to the inputs if necessary. - For example, consider:: - - def fn(a, b): - a.add_(b) - return a - - For a call like `fn(x, y)`, the variable `x` outside is also mutated. Hence just - functionalizing is not enough for preserving the original semantics. A "special" - input mutation step needs to be inserted at the end.:: - - # After functionalization, without input mutation "fix up". - # This is not semantically the same. The variable outside the function call that - # was passed in as `a` is not mutated. - def fn(a, b): - new_a = a + b - return new_a - - # Functionalization with input mutation "fix up" that preserves semantics. - def fn(a, b): - new_a = a + b - - # Copying the new data back to the inputs - a.copy_(new_a) - - return new_a - - For ONNX inference, it is recommended to run ``RemoveInputMutation`` after this pass. - ``RemoveInputMutation`` removes the "fix up" nodes that were added by ``Functionalize``, - which are not needed for ONNX inference. - """ - - def __init__( - self, - module: torch.fx.GraphModule, - enable_dynamic_axes: bool, - allow_fake_constant: bool | None = False, - ): - super().__init__(module) - self.enable_dynamic_axes = enable_dynamic_axes - self.allow_fake_constant = allow_fake_constant - - def _functionalize(self, function: Callable) -> Callable: - # Working around a dispatcher issue with `torch.func.functionalize` when used - # together with `make_fx`. - # Ref: https://github.com/pytorch/pytorch/issues/99774#issuecomment-1527949391 - def wrapped(*inputs): - inputs_functional = pytree.tree_map_only( - torch.Tensor, torch._to_functional_tensor, inputs - ) - torch._enable_functionalization(reapply_views=True) - try: - out = function(*inputs_functional) - finally: - torch._disable_functionalization() - - flat_inputs_functional = pytree.tree_leaves(inputs_functional) - for input_functional in flat_inputs_functional: - if isinstance(input_functional, torch.Tensor): - torch._sync(input_functional) - pytree.tree_map(torch._sync, out) - out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) - return out_unwrapped - - return wrapped - - def _run(self, *args) -> torch.fx.GraphModule: - # To preserve stack trace info after `make_fx`. - module = _utils.wrap_graph_module_for_node_meta_preservation(self.module) - - functionalized_callable = self._functionalize(module) - - # Mimic `torch._dynamo.export(aten_graph=True)` behavior in invoking `make_fx`. - # TODO: May need revisit for user fake mode export + dynamic shape scenario. - fake_mode: fake_tensor.FakeTensorMode | None = self.fake_mode - maybe_fake_args = self._maybe_fakefy_args(fake_mode, *args) - if fake_mode is not None: - # Using existing fake mode as context, signal `make_fx` that it does not need - # to create a new fake mode by passing tracing_mode as "real". - tracing_mode = "real" - else: - # Existing fake mode not found, signal `make_fx` to create one. - fake_mode = contextlib.nullcontext() # type: ignore[assignment] - tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake" - - assert fake_mode is not None # for mypy - with fake_tensor.unset_fake_temporarily(), fake_mode: - graph_module = proxy_tensor.make_fx( - functionalized_callable, - decomposition_table={}, - tracing_mode=tracing_mode, - _allow_non_fake_inputs=True, - _allow_fake_constant=bool(self.allow_fake_constant), - )(*maybe_fake_args) - - # Rename placeholder targets to match the original module's signature since - # We don't want to map forward(x, y, z) to forward(arg0, arg1, arg2). - _utils.replace_placeholder_name_and_target(graph_module, self.module) - - return graph_module - - -class RemoveInputMutation(_pass.Transform): - """Remove `aten.copy_.default` nodes that mutate module inputs. - - This pass is recommended to be used after ``Functionalization`` pass. - ``Functionalization`` pass adds `aten.copy_.default` nodes to the graph - when it detects mutations to inputs. These nodes are not needed for ONNX export - for inference. They could be useful for training. - """ - - def _run(self, *args) -> torch.fx.GraphModule: - for node in reversed(self.module.graph.nodes): - if ( - node.op == "call_function" - and node.target == torch.ops.aten.copy_.default - and len(node.users) == 0 - and isinstance(node.args[0], torch.fx.Node) - and node.args[0].op == "placeholder" - ): - self.module.graph.erase_node(node) - return self.module diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py deleted file mode 100644 index 18a424826bfe..000000000000 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ /dev/null @@ -1,857 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import abc -import collections -import copy -import operator -from typing import Any, Final, TYPE_CHECKING - -import torch -import torch.fx -from torch.onnx._internal.fx import _pass -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - from collections.abc import Generator, Iterator, Sequence - - -_FX_TRACER_NN_MODULE_META_TYPE = tuple[str, type] -"""Legacy type of item from `node.meta["nn_module_stack"].items()` produced by FX symbolic tracer.""" -_FX_TRACER_NN_MODULE_STACK_META_TYPE = collections.OrderedDict -"""Legacy type of `node.meta["nn_module_stack"]` produced by FX symbolic tracer.""" - -_DYNAMO_NN_MODULE_META_TYPE = tuple[str, tuple[str, type]] -"""Type of item from `node.meta["nn_module_stack"].items()` produced by FX dynamo tracer.""" -_DYNAMO_NN_MODULE_STACK_META_TYPE = dict[str, _DYNAMO_NN_MODULE_META_TYPE] -"""Type of `node.meta["nn_module_stack"]` produced by FX dynamo tracer.""" - - -class _ModuleMeta: - """Meta information about a module. - - This class is used to represent the module information in a more structured way. - It parses raw module information from a single item from - `node.meta["nn_module_stack"].items()`. - - See the uses of `from_raw_meta`, `from_fx_tracer_produced_raw_meta`, and - `from_dynamo_produced_raw_meta` for how to create an instance. - - Attributes: - _module_class: The class of the module. E.g. `torch.nn.module.sparse.Embedding`. - _module_name: The name of the module. E.g. `L__self___h_1_mlp_c_proj`. - _raw_meta: The raw meta '(module_name, node.meta["nn_module_stack"][module_name])'. - """ - - _module_class: Final[type | str | None] # type: ignore[misc] - _module_name: Final[str] # type: ignore[misc] - _raw_meta: Final[tuple[Any, Any]] # type: ignore[misc] - - def __init__( - self, - module_name: str, - module_class: type | str | None, - raw_meta: tuple[Any, Any], - ): - self._module_name = module_name - self._module_class = module_class - self._raw_meta = raw_meta - - @property - def module_display_name(self) -> str: - """The display name of the module. - - E.g. `h_1_mlp_c_proj`. - """ - # E.g., from 'L__self___h_1_mlp_c_proj' to 'h_1_mlp_c_proj'. - name = self.module_name - name = name.removeprefix("L__self___") - return name - - @property - def qualified_module_class_name(self) -> str: - """Qualified name of the module class. - - E.g. `torch_nn_module_sparse_Embedding`. - """ - if self._module_class is None: - return "" - mod_cls = self._module_class - if isinstance(mod_cls, type): - mod_cls = mod_cls.__module__ + "." + mod_cls.__qualname__ - return mod_cls.replace(".", "_") - - @property - def module_class_name(self) -> str: - """Name of the module class. - - E.g. `Embedding`. - """ - if self._module_class is None: - return "" - if isinstance(self._module_class, type): - return self._module_class.__name__ - return self._module_class - - @property - def module_name(self) -> str: - """Name of the module. - - E.g. `L__self___h_1_mlp_c_proj`. - """ - return self._module_name - - @property - def raw_meta(self) -> tuple[Any, Any]: - """Returns the raw module meta data. - - I.e. (module_name, node.meta['nn_module_stack'][module_name]). - """ - return self._raw_meta - - def __eq__(self, other: object, /) -> bool: - if not isinstance(other, _ModuleMeta): - return False - return ( - self._module_name == other._module_name - and self._module_class == other._module_class - ) - - def __hash__(self) -> int: - return hash((self._module_name, self._module_class)) - - def __repr__(self) -> str: - return f"ModuleMeta(name={self._module_name}, class={self._module_class})" - - @classmethod - def create_root(cls) -> _ModuleMeta: - """Create an empty module meta representing root module.""" - return _ModuleMeta("", None, ("", None)) - - @classmethod - def from_fx_tracer_produced_raw_meta( - cls, raw_meta: _FX_TRACER_NN_MODULE_META_TYPE - ) -> _ModuleMeta: - """Create a module meta from raw meta produced by FX symbolic tracer.""" - module_name, module_class = raw_meta - return _ModuleMeta(module_name, module_class, raw_meta) - - @classmethod - def from_dynamo_produced_raw_meta( - cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE - ) -> _ModuleMeta: - """Create a module meta from raw meta produced by FX dynamo tracer.""" - module_name, (_qualified_name, module_class) = raw_meta - return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) - - @classmethod - def from_raw_meta( - cls, - raw_meta: _FX_TRACER_NN_MODULE_META_TYPE | _DYNAMO_NN_MODULE_META_TYPE, - ) -> _ModuleMeta: - if ( - isinstance(raw_meta, tuple) - and len(raw_meta) == 2 - and isinstance(raw_meta[1], type) - ): - # Trying to do `instance(raw_meta, _FX_TRACER_NN_MODULE_META_TYPE)` - return _ModuleMeta.from_fx_tracer_produced_raw_meta(raw_meta) - if ( - isinstance(raw_meta, tuple) - and len(raw_meta) == 2 - and isinstance(raw_meta[1], tuple) - ): - # Trying to do `instance(raw_meta, _DYNAMO_NN_MODULE_META_TYPE)` - return _ModuleMeta.from_dynamo_produced_raw_meta(raw_meta) - raise TypeError( - f"Unknown type of raw meta item from node.meta['nn_module_stack'].items(): {type(raw_meta)}" - ) - - -class _ModuleStackMeta: - """Meta information about the module call stack. - - This class is used to represent the module call stack information in a more - structured way. It parses raw module stack information from `node.meta["nn_module_stack"]`. - - Example of raw module stack information: - - If produced by dynamo: - - { - 'L__self___h_1': ( - "L['self'].h[1]", - - ), - 'L__self___h_1_attn': ( - "L['self'].h[1].attn", - - ) - } - - If produced by fx.symbolic_trace: - - { - 'h.1': , - 'h.1.attn': - } - """ - - _module_stack: Final[list[_ModuleMeta]] # type: ignore[misc] - - def __init__( - self, - nn_module_stack_meta: _FX_TRACER_NN_MODULE_STACK_META_TYPE - | _DYNAMO_NN_MODULE_STACK_META_TYPE - | None, - is_exported_program: bool = True, - ): - self._module_stack = [] - if nn_module_stack_meta is None: - return - raw_meta = copy.copy(nn_module_stack_meta) - for item in raw_meta.items(): - # If produced by torch.export.export, there is another call stack layer - # that we need to skip - if is_exported_program: - is_exported_program = False - continue - self.push(_ModuleMeta.from_raw_meta(item)) # type: ignore[arg-type] - - def __len__(self) -> int: - return len(self._module_stack) - - def __getitem__(self, index: int) -> _ModuleMeta: - return self._module_stack[index] - - def __iter__(self) -> Iterator[_ModuleMeta]: - return iter(self._module_stack) - - def is_empty_or_root(self) -> bool: - return len(self._module_stack) == 0 - - def top(self) -> _ModuleMeta: - """Returns the top module meta in the stack. I.e., the meta for leaf module. - - Example: - - Consider the following module stack: - - stack = [GPT, block1, Attention_1, MLP] - - stack.top() == MLP - """ - if self.is_empty_or_root(): - return _ModuleMeta.create_root() - return self._module_stack[-1] - - def is_superset_of( - self, - module_stack: _ModuleStackMeta, - ) -> bool: - """Determines if self is a superset of the provided module stack. - - I.e., If self includes all elements from the provided module stack, plus additional - elements on top. If self is empty or root, this method always return False. - - Example: - - Consider the following module stack: - - stack_1 = [GPT, block1, Attention_1, MLP] - stack_2 = [GPT, block1] - - stack_1.is_superset_of(stack_2) == True - stack_2.is_superset_of(stack_1) == False - - stack_3 = [GPT, block2, Attention_1] - - stack_1.is_superset_of(stack_3) == False - stack_3.is_superset_of(stack_1) == False - """ - if self.is_empty_or_root(): - return False - - if module_stack.is_empty_or_root() is None: - return True - - if len(self) <= len(module_stack): - return False - - for i, parent_key in enumerate(module_stack): - if self[i] != parent_key: - return False - - return True - - def push(self, module_meta: _ModuleMeta) -> None: - """Pushes a module meta to the stack.""" - self._module_stack.append(module_meta) - - def __eq__(self, other: object, /) -> bool: - if not isinstance(other, _ModuleStackMeta): - return False - return self._module_stack == other._module_stack - - @property - def raw_meta(self) -> dict[str, tuple[str, type]] | None: - """Returns the raw module stack meta data, i.e. node.meta['nn_module_stack'].""" - return { - module_meta.raw_meta[0]: module_meta.raw_meta[1] - for module_meta in self._module_stack - } - - def __repr__(self) -> str: - return f"ModuleStackMeta({self._module_stack})" - - @property - def module_display_name(self) -> str: - """Returns the module display name of the top module.""" - return self.top().module_display_name - - @property - def qualified_module_class_name(self) -> str: - """Returns the qualified module class name of the top module.""" - return self.top().qualified_module_class_name - - @property - def module_class(self) -> type | str | None: - """Returns the module class of the top module.""" - return self.top()._module_class - - -def _module_stack_meta_from_node( - node: torch.fx.Node, is_exported_program: bool = False -) -> _ModuleStackMeta: - return _ModuleStackMeta( - node.meta.get("nn_module_stack"), is_exported_program=is_exported_program - ) - - -def _get_unique_module_name(module_names: dict[str, int], module_name: str) -> str: - module_names.setdefault(module_name, 0) - module_names[module_name] += 1 - return f"{module_name}_{module_names[module_name]}" - - -class _IRNode(abc.ABC): - """Base class for IR nodes. - - IR nodes are used for Modularize pass only. They add a layer of abstraction on top of - torch.fx.Node. - - [NOTE: Modularize Pass Implementation] - The main job of the pass is to group `fx.Node`s that belong to the same `nn.Module` - forward call, and then create `call_module` node and sub `fx.GraphModule` from them. - Each `fx.Node` possesses an `nn_module_stack` meta data that contains information - about the module call stack. See `_ModuleStackMeta` for examples. - - Analysis step - ------------- - - Each module call is identified by a set of base stack layers. For each module call, - the pass creates a `_ModuleNode` and groups the sequence of nodes that shares the - same base stack layers. - - For example, - - stack_of_node_0 = [GPT, block0] - stack_of_node_1 = [GPT, block1] - stack_of_node_2 = [GPT, block1, Attention1, MLP] - stack_of_node_3 = [GPT, block1, Attention1] - stack_of_node_4 = [GPT, block2] - - All nodes belong to the `GPT` module call, since they share the base stack layers [GPT]. - [node_1, node_2, node_3] are grouped for `GPT.block1`, because they share the base - stack layers [GPT, block1]. And [node_2, node_3] for `GPT.block1.Attention1`, [node_0] - for `GPT.block0`, and [node_4] for `GPT.block2` respectfully. - - After the analysis step, a hierarchical representation is generated. - - For above example, the representation is: - - _ModuleNode(GPT) - _ModuleNode(block0) - _LeafNode(node_0) - _ModuleNode(block1) - _LeafNode(node_1) - _ModuleNode(Attention1) - _ModuleNode(MLP) - _LeafNode(node_2) - _LeafNode(node_3) - _ModuleNode(block2) - _LeafNode(node_4) - - Construction step - ----------------- - - The second step is to build the actual `call_module` node and the sub `fx.GraphModule`. - This is done recursively from the leaf `_ModuleNode` to the root. - - For example, the first submodule to be built is `GPT.block1.Attention1.MLP`. Below pair - is generated from `_ModuleNode(MLP)`. - - fx.GraphModule(GPT.block1.Attention1.MLP) - graph: - node_2 - - new_mlp_node = `call_module[GPT.block1.Attention1.MLP](...)` - - Next, the `GPT.block1.Attention1` submodule is built. Below is generated from - `_ModuleNode(Attention1)`. - - fx.GraphModule(GPT.block1.Attention1) - graph: - new_mlp_node - node_3 - - new_attention1_node = `call_module[GPT.block1.Attention1](...)` - - Until every submodule is built, the new modularized `fx.GraphModule` is generated. - - Alternatives - ------------ - - The current algorithm adopts a top down approach. A bottom up approach is similar. - In contrast to these two, an alternative flat order approach is also possible, where - each node is traversed and copied to the corresponding submodule. - - The advantage of the current approach lies in the encapsulation of the fx.GraphModule - construction for each individual submodule within a single `build_module` method, which - can be called separately once the analysis phase is completed, making debugging more - convenient. - - Regarding construction step, an alternative implementation is to utilize `fx.Interpreter` - for traversing all the nodes under the flattened root module and copying the nodes - into their respective submodule under construction. This approach is not adopted because - - 1. It uses the flat order approach discussed above. This means one cannot individually - construct a submodule and examine it while debugging. - - 2. The graph execution functionality of `fx.Interpreter` is not necessary for the - purpose of this pass. Ignoring that, `fx.Interpreter.run` achieves the same effect - as a for loop over all the nodes. - """ - - @property - @abc.abstractmethod - def stack_meta(self) -> _ModuleStackMeta: - """The module stack meta data associated with this node.""" - ... - - @property - @abc.abstractmethod - def stack_trace(self) -> str | None: - """The stack trace associated with this node.""" - ... - - -class _ModuleNode(_IRNode): - """Representing a sequence of fx.Nodes to be formed into a fx.GraphModule. - - This class encapsulates metadata and provides building block methods to construct this - layered abstraction from a sequence of flat fx.Nodes. - - Attributes: - - _stack_meta: Metadata of the module stack. - - _nodes: List of IR nodes in the module. - - _reference_root_module: Reference to the root flat fx.GraphModule instance. - """ - - def __init__( - self, reference_root_module: torch.fx.GraphModule, stack_meta: _ModuleStackMeta - ): - self._stack_meta = stack_meta - self._nodes: list[_IRNode] = [] - self._reference_module = reference_root_module - - @property - def stack_meta(self) -> _ModuleStackMeta: - return self._stack_meta - - @property - def stack_trace(self) -> str | None: - assert self._nodes - return self._nodes[0].stack_trace - - def __str__(self) -> str: - return f"ModuleNode({self._stack_meta})" - - def is_same_module_as(self, node: _IRNode) -> bool: - """Determines if the provided node pertains to the same module as this node.""" - return self.stack_meta == node.stack_meta - - def is_parent_module_of(self, node: _IRNode) -> bool: - """Determines if this node represents a parent module of the provided node.""" - return node.stack_meta.is_superset_of(self.stack_meta) - - def add_leaf_node(self, leaf_node: _LeafNode) -> None: - """Adds a leaf node to the module. - - The leaf node must belong to the same or a child module. This method will recursively - construct _ModuleNode instance based on the stack_meta information of the leaf node. - """ - if self.is_same_module_as(leaf_node) or leaf_node.fx_op == "call_module": - self._nodes.append(leaf_node) - elif leaf_node.fx_op == "placeholder": - # Although the original placeholder has empty nn_module_stack, the placeholder lifted - # from get_attr nodes by exported program has their original nn_module_stack. Here - # we need to avoid them building submodule. - self._nodes.append(leaf_node) - elif self.is_parent_module_of(leaf_node): - # This node belongs in a submodule. - # Check if the last node is a submodule and if it is the parent of this node. - last_node = self._nodes[-1] if self._nodes else None - if isinstance(last_node, _ModuleNode) and ( - last_node.is_parent_module_of(leaf_node) - or last_node.is_same_module_as(leaf_node) - ): - # This node belongs to the last_node. - last_node.add_leaf_node(leaf_node) - else: - # Create a new SubmoduleNode for the immediate child module of the current - # module. The leaf node may be a grandchild of the current module. - # Example: - # self.stack_meta = [A, B, C] - # leaf_node.stack_meta = [A, B, C, D, E, F] - # Create a new ModuleNode with stack_meta = [A, B, C, D] and add leaf_node to it. - stack_meta = copy.deepcopy(self.stack_meta) - stack_meta.push(leaf_node.stack_meta[len(self.stack_meta)]) - last_node = _ModuleNode( - self._reference_module, - stack_meta, - ) - self._nodes.append(last_node) - last_node.add_leaf_node(leaf_node) - else: - raise AssertionError( - f"Node {leaf_node} ({leaf_node.stack_meta}) does not belong to module " - f"{self._stack_meta}." - ) - - def fx_nodes(self) -> Generator[torch.fx.Node, None, None]: - """Returns an iterator for the sequence of fx nodes this instance holds.""" - for node in self._nodes: - if isinstance(node, _ModuleNode): - yield from node.fx_nodes() - else: - assert isinstance(node, _LeafNode) - yield node.fx_node - - def module_inputs(self) -> Sequence[torch.fx.Node]: - """Extract module inputs from the sequence of fx nodes this instance holds. - - All node args that are produced by nodes outside of the module are considered module - inputs. The order of returned module inputs is the same as the their use order. - - ### Known limitations - - The original ordering of module inputs is not preserved. There is no meta information - to be found from the `fx.GraphModule` that can be used to recover the original ordering. - - Returns: - Sequence of module inputs. - """ - nodes = list(self.fx_nodes()) - assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." - module_inputs: dict[torch.fx.Node, None] = {} - node_set: set[torch.fx.Node] = set(nodes) - - def _extract_arg_if_node_outside_module(arg: Any): - if isinstance(arg, torch.fx.Node) and arg not in node_set: - module_inputs[arg] = None - - for node in nodes: - pytree.tree_map(_extract_arg_if_node_outside_module, node.args) - pytree.tree_map(_extract_arg_if_node_outside_module, node.kwargs) - return list(module_inputs.keys()) - - def module_outputs(self) -> Sequence[torch.fx.Node]: - """Extract module outputs from the sequence of fx nodes this instance holds. - - All nodes that are used by nodes outside of the module are considered module - outputs. The order of returned module outputs is the same as the their creation order. - - ### Known limitations - - The original ordering of module outputs is not preserved. There is no meta information - to be found from the `fx.GraphModule` that can be used to recover the original ordering. - - Returns: - Sequence of module outputs. - """ - nodes = list(self.fx_nodes()) - assert len(nodes) > 0, "Cannot extract module inputs from empty nodes." - # Need ordered set. Emulate with dict. - module_outputs: dict[torch.fx.Node, None] = {} - node_set: set[torch.fx.Node] = set(nodes) - - for node in nodes: - if any(user not in node_set for user in node.users): - module_outputs[node] = None - return list(module_outputs.keys()) - - def build_module(self, module_names: dict[str, int]) -> torch.fx.GraphModule: - """ - Constructs the fx.GraphModule for this node, registering submodules as necessary. - - Args: - module_names: A dictionary of module names and their counts. This is used to - generate unique module names for submodules. This should be an empty - dictionary when the method is called on a root module. - """ - module_class_name = self._stack_meta.qualified_module_class_name - fx_graph = torch.fx.Graph() - copy_env: dict[torch.fx.Node, torch.fx.Node] = {} - - def _arg_transform(node: torch.fx.Node) -> torch.fx.Node: - return copy_env[node] - - ref_inputs = self.module_inputs() - for node in ref_inputs: - copy_env[node] = fx_graph.placeholder(node.name, node.type) - copy_env[node].meta = copy.copy(node.meta) - - for ir_node in self._nodes: - if isinstance(ir_node, _LeafNode): - fx_node = ir_node.fx_node - copy_env[fx_node] = fx_graph.node_copy( - fx_node, arg_transform=_arg_transform - ) - continue - - assert isinstance(ir_node, _ModuleNode) - # Create fx.GraphModule for child submodule. - submodule = ir_node.build_module(module_names) - ref_submodule_inputs = ir_node.module_inputs() - ref_submodule_outputs = ir_node.module_outputs() - unique_submodule_name = _get_unique_module_name( - module_names, ir_node.stack_meta.module_display_name - ) - # Link the newly generated sub fx.GraphModule with the root reference module. - # This step is essential to meet the needs of the subsequent fx.GraphModule initialization - # for the fx.GraphModule being created by this method. - # The initialization of fx.GraphModule will replicate all necessary attributes from a reference - # fx.GraphModule for the fx.Graph. While the root reference module possesses all - # parameters and buffers, it does not include the newly created sub fx.GraphModule. - # Therefore, it's necessary to register it under the root reference at this stage. - self._reference_module.add_submodule(unique_submodule_name, submodule) - - # create call_module fx.Node - submodule_node = fx_graph.call_module( - unique_submodule_name, - tuple(_arg_transform(node) for node in ref_submodule_inputs), - ) - if len(ref_submodule_outputs) > 1: - # Module node has multiple output. Create 'getitem' node for each output. - submodule_node.meta["val"] = tuple( - ref_output.meta.get("val") for ref_output in ref_submodule_outputs - ) - for i, ref_output in enumerate(ref_submodule_outputs): - getitem_node = fx_graph.call_function( - operator.getitem, - args=(submodule_node, i), - type_expr=ref_output.type, - ) - getitem_node.meta = copy.copy(ref_output.meta) - # Make a copy for "nn_module_stack" since the current module will be - # popped from the stack for this 'getitem' node. - getitem_node.meta["nn_module_stack"] = copy.copy( - ref_output.meta["nn_module_stack"] - ) - # The node is associated with the parent module. - getitem_node.meta["nn_module_stack"].popitem() - copy_env[ref_output] = getitem_node - else: - # Module node has single output. Use module node directly. - copy_env[ref_submodule_outputs[0]] = submodule_node - submodule_node.meta = copy.copy(ref_submodule_outputs[0].meta) - - # Update meta for new call_module node. - if (stack_trace := ir_node.stack_trace) is not None: - submodule_node.meta["stack_trace"] = stack_trace - raw_module_stack_meta = ir_node.stack_meta.raw_meta - assert raw_module_stack_meta is not None - submodule_node.meta["nn_module_stack"] = copy.copy(raw_module_stack_meta) - # The node is associated with the parent module. - submodule_node.meta["nn_module_stack"].popitem() - - new_nodes = fx_graph.nodes - # Skip if the last node is already 'output'. This is the case for root module. - # Otherwise create an 'output' node for the inferred outputs. - if next(iter(reversed(new_nodes))).op != "output": - ref_submodule_outputs = self.module_outputs() - new_outputs = [copy_env[ref_output] for ref_output in self.module_outputs()] - node = fx_graph.output( - new_outputs[0] if len(new_outputs) == 1 else new_outputs - ) - - graph_module = torch.fx.GraphModule( - self._reference_module, fx_graph, module_class_name - ) - if (module_class := self._stack_meta.module_class) is not None: - graph_module.meta["onnx"] = _pass.GraphModuleOnnxMeta( - _pass.PackageInfo.from_python_class(module_class) - ) - return graph_module - - -class _LeafNode(_IRNode): - """Representing a single fx.Node.""" - - def __init__(self, node: torch.fx.Node, is_exported_program: bool = False): - self._node = node - self._stack_meta = _module_stack_meta_from_node( - node, is_exported_program=is_exported_program - ) - - @property - def fx_op(self) -> str: - """Syntax sugar for self.fx_node.op.""" - return self._node.op - - @property - def fx_node(self) -> torch.fx.Node: - """Returns the fx.Node this instance represents.""" - return self._node - - @property - def stack_meta(self) -> _ModuleStackMeta: - """Returns the module stack meta data associated with this node.""" - return self._stack_meta - - @property - def stack_trace(self) -> str | None: - """Returns the stack trace associated with this node.""" - return self.fx_node.meta.get("stack_trace") - - def __str__(self) -> str: - return f"LeafNode({self._node})" - - -class Modularize(_pass.Transform): - """Transforms a flattened `fx.GraphModule` into a modular structure. - - In the flattened `fx.GraphModule`, each `nn.Module` forward call has been traced as - a sequence of `fx.Node`s. All these `fx.Node`s are flattened and reside in the same - `fx.GraphModule`. `fx.GraphModule` could be from `torch.export.ExportedProgram` or - directly generated by `torch._dynamo.export` with torch.nn.Module. - - This pass generates a new `fx.GraphModule`. It groups the flattened `fx.Node`s that belong - to the same `nn.Module` forward call into a sub `fx.GraphModule`. It then replaces the - sequence of flattened `fx.Node`s with a single `call_module` node, which is linked with - the sub `fx.GraphModule` by `node.target`. The sub `fx.GraphModule` is registered as a - submodule of the new `fx.GraphModule`. - - The process is done based on information from the `nn_module_stack` metadata of each node, i.e. - `node.meta["nn_module_stack"]`. For more implementation details, see [NOTE: Modularize Pass Implementation]. - - An fx submodule under this context can typically be interpreted in three different ways: - - 1. As an embodiment of an nn.Module class, which is considered stateless. - Its execution path can vary depending on the configuration of module initialization, - which should also be part of the inputs. - - 2. As a representation of an nn.Module instance. It maintains the state initialized in the module. - The execution path can vary based on actual input data. - - 3. As a captured call of an nn.Module instance, where the execution path - is set. - - The generality decreases along this list. Within the scope of this function, the pass - creates fx submodules according to the third interpretation. - - The first interpretation is the most general case. It requires complex analysis and additional - metadata and code information to construct its general form. Consider an example nn.Module - that generates arbitrary submodules based on an initialization configuration file. It's impractical - to extract this logic for the generated fx submodule to function with arbitrary configuration. - - The second interpretation demands less analysis and is sturdier than the - first. In most use cases, it's equivalent to the third. It only differs in exceptional situations - where a complex nn.Module instance is called multiple times, each with a different set of inputs - leading to a unique execution branching path. - - The third interpretation is the most specific scenario. It necessitates the minimum - analysis and creates the most stable representation. The drawback is that it - generates more redundancy than the other two methods. If needed, a subsequent post-processing - pass can be applied to consolidate completely identical functions and reduce duplication. - - ### Known constraints - Two successive calls to the same module instance will be conflated. They are indistinguishable. - This is due to limitations of the current fx metadata "nn_module_stack". - - [NOTE: Modularize pass ordering] - This pass groups fx nodes into subgraphs that reside within the `call_module` fx node. - Other fx passes (including some outside the exporter) might not recognize `call_module`. - They may assume that all nodes are flattened. Hence it is recommended to invoke this pass - as the last pre onnx export fx pass. If not for this consideration, this operation could - potentially be relocated anywhere earlier in the pipeline. - - Example: - - >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import torch - >>> from torch.onnx._internal.fx import passes - >>> - >>> class CustomModule(torch.nn.Module): - >>> def __init__(self) -> None: - >>> super().__init__() - >>> self.embedding = torch.nn.Embedding(10, 32) - >>> self.relu = torch.nn.ReLU() - >>> - >>> def forward(self, x): - >>> out = self.embedding(x) - >>> out = self.relu(out) - >>> return out - >>> - >>> class TestModule(torch.nn.Module): - >>> def __init__(self) -> None: - >>> super().__init__() - >>> self.layer = CustomModule() - >>> self.linear = torch.nn.Linear(32, 10) - >>> - >>> def forward(self, x): - >>> out = self.layer(x) - >>> out = self.linear(out) - >>> return out - >>> - >>> gm, _ = torch._dynamo.export(TestModule(), aten_graph=True)( - ... torch.tensor([0, 1, 2]) - ... ) - >>> gm.print_readable() - - >>> gm = passes.Modularize( - ... gm, - ... ).run() - >>> gm.print_readable() - - """ - - def __init__( - self, - module: torch.fx.GraphModule, - is_exported_program: bool = False, - ): - super().__init__(module) - self.module = module - self.is_exported_program = is_exported_program - - def _run(self) -> torch.fx.GraphModule: - # DCE to remove unused nodes. - # If a submodule is unused, it is hard to analyze which nodes constitutes the submodule - # outputs. But since it is unused, we can just remove it. - self.module.graph.eliminate_dead_code() - - reference_module = torch.fx.GraphModule(self.module, self.module.graph) - root_module_node = _ModuleNode( - reference_module, - _ModuleStackMeta( - nn_module_stack_meta=None, is_exported_program=self.is_exported_program - ), - ) - for fx_node in self.module.graph.nodes: - root_module_node.add_leaf_node( - _LeafNode(fx_node, is_exported_program=self.is_exported_program) - ) - return root_module_node.build_module({}) diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py deleted file mode 100644 index a14d07b9aa19..000000000000 --- a/torch/onnx/_internal/fx/passes/readability.py +++ /dev/null @@ -1,130 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -import torch -from torch.onnx._internal.fx import _pass - - -if TYPE_CHECKING: - from collections.abc import Sequence - - -logger = logging.getLogger(__name__) - - -class RestoreParameterAndBufferNames(_pass.Transform): - """Restore parameter and buffer names from original nn.module. - - This pass is useful for readability of the exported ONNX graph. It restores the - parameter and buffer names from the original nn.module. For example, if the original - nn.module has a parameter named `root.linear.0.weight`, and the parameter is renamed to - `_param_constant9` by FX, this pass will rename it back. - - This pass must be run after `Decompose` pass. Because this pass is expected to be called on - `fx.GraphModule` produced by `proxy_tensor.make_fx`, where all parameters and buffers - are registered at root level. - """ - - def __init__( - self, - fx_module: torch.fx.GraphModule, - original_nn_module: torch.nn.Module, - ): - super().__init__(fx_module) - self.original_nn_module = original_nn_module - - def _rename_param_and_buffer( - self, - nodes: Sequence[torch.fx.Node], - new_name: str, - ) -> None: - """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.""" - assert len(nodes) > 0, "`nodes` cannot be empty" - assert len({node.target for node in nodes}) == 1, ( - "`nodes` must all have same `target`" - ) - old_name = nodes[0].target - assert isinstance(old_name, str), f"Expected str, got type({old_name})" - # Parameter/buffer name cannot contain "." - normalized_name = new_name.replace(".", "/") - attr_value = getattr(self.module, old_name) - setattr(self.module, normalized_name, attr_value) - delattr(self.module, old_name) - for node in nodes: - with self.module.graph.inserting_before(node): - new_node = self.module.graph.get_attr(normalized_name) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - self.module.graph.erase_node(node) - logger.info( - "Renamed 'self.%s' to 'self.%s', " - "normalized from original parameter name '%s'.", - old_name, - normalized_name, - new_name, - ) - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - """Restore parameter and buffer names from original module. - - For each `get_attr` node, if the target is a str representing a parameter or buffer - under `self.module`, we rename the parameter or buffer to its original name. - The parameters and buffers between `self.module` and `self.original_nn_module` refer - to the same objects, allowing us to use it as key to retrieve the original name. - """ - assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args" - assert len(kwargs) == 0, ( - "RestoreParameterAndBufferNames does not take any kwargs" - ) - # state_to_readable_name[parameter/buffer] returns the original readable name of - # the parameter/buffer. E.g., "self.linear.weight". - state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {} - state_to_readable_name.update( - {v: k for k, v in self.original_nn_module.named_parameters()} - ) - state_to_readable_name.update( - {v: k for k, v in self.original_nn_module.named_buffers()} - ) - - # old_name_to_nodes[old_name] returns a tuple of (nodes, new_name) - # where `nodes` is a list of `get_attr` nodes with `old_name` as `target` and - # `new_name` is the new readable name. - old_name_to_nodes: dict[str, tuple[list[torch.fx.Node], str]] = {} - - for node in self.module.graph.nodes: - if node.op == "get_attr": - assert isinstance(node.target, str), ( - f"Expected str, got type({node.target})" - ) - if node.target.find(".") != -1: - raise RuntimeError( - f"Unexpected target {node.target} in get_attr, found '.' in target. " - f"All parameters and buffers are expected to be registered at root level, " - f"i.e., self.module. " - ) - if node.target in old_name_to_nodes: - # We have already processed this parameter/buffer. - old_name_to_nodes[node.target][0].append(node) - continue - attr_value = getattr(self.module, node.target) - if ( - isinstance(attr_value, (torch.nn.Parameter, torch.Tensor)) - and attr_value in state_to_readable_name - ): - readable_name = state_to_readable_name[attr_value] - old_name_to_nodes[node.target] = ([node], readable_name) - continue - - logger.info( - "Cannot find readable name for self.%s: %s. The name is unchanged.", - node.target, - type(attr_value), - ) - - for nodes, new_name in old_name_to_nodes.values(): - self._rename_param_and_buffer(nodes, new_name) - - return self.module diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py deleted file mode 100644 index 504dea1d8424..000000000000 --- a/torch/onnx/_internal/fx/passes/virtualization.py +++ /dev/null @@ -1,96 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -from typing import TYPE_CHECKING - -import torch -from torch.onnx._internal.fx import _pass - - -if TYPE_CHECKING: - import torch.fx - - -class MovePlaceholderToFront(_pass.Transform): - """This pass move all placeholder nodes to the front of the graph node list. - - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - graph_module = self.module - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return graph_module - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - return graph_module - - -class ReplaceGetAttrWithPlaceholder(_pass.Transform): - """Replace get_attr with placeholder. - - The parameters and buffers accessed by the original get_attr are returned; - they are useful when creating random inputs for the modified graph_module. - """ - - _replaced_attrs: tuple[torch.Tensor, ...] | None - - @property - def replaced_attrs(self) -> tuple[torch.Tensor, ...]: - """The list of replaced weight tensors.""" - assert self._replaced_attrs is not None, ( - "Must run ReplaceGetAttrWithPlaceholder first" - ) - return self._replaced_attrs - - def _run(self, *args, **kwargs) -> torch.fx.GraphModule: - graph_module = self.module - graph = graph_module.graph - replaced_attrs: list[torch.Tensor] = [] - for node in graph.nodes: - if node.op == "get_attr": - replaced_attr: torch.Tensor | None = None - # get_attr could retrieve either parameter or buffer, so - # we need to try both. - try: - replaced_attr = graph_module.get_parameter(node.target) - except AttributeError: - # It's possible that model author use buffer instead of - # parameter to store trainable weights. In this case, - # 1. get_parameter will throw something like - # AttributeError: `bias` is not an nn.Parameter. - # 2. get_buffer should work. - replaced_attr = graph_module.get_buffer(node.target) - - # Reassign op type so that get_attr node becomes placeholder node. - node.op = "placeholder" - # The target name in placeholder must be a valid Python identifier. - # Thus, we replace, e.g., "module.submodule.weight" with - # "module_submodule_weight". - node.target = node.target.replace(".", "_") - # Default value is None. This is needed as long as the "graph_module" - # has optional inputs. Assume the original forward signature is - # def forward(self, x, y=None) - # and the replaced get_attr node has target "z". Then, the modified - # signature should be - # def forward(self, x, y=None, z=None) - # Without the following line, the signature will be - # def forward(self, x, y=None, z) - # , which is not valid Python code. - node.args = (None,) - - replaced_attrs.append(replaced_attr) - - self._replaced_attrs = tuple(replaced_attrs) - - return graph_module diff --git a/torch/onnx/_internal/fx/registration.py b/torch/onnx/_internal/fx/registration.py deleted file mode 100644 index ec6fc638e3f2..000000000000 --- a/torch/onnx/_internal/fx/registration.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Module for handling ATen to ONNX functions registration.""" - -from __future__ import annotations - -import dataclasses -from typing import TYPE_CHECKING - - -# We can only import onnx from this module in a type-checking context to ensure that -# 'import torch.onnx' continues to work without having 'onnx' installed. We fully -# 'import onnx' inside of dynamo_export (by way of _assert_dependencies). -if TYPE_CHECKING: - import types - - import onnxscript # type: ignore[import] - - import torch._ops - - -@dataclasses.dataclass(frozen=True, eq=True) -class ONNXFunction: - """A wrapper of onnx-script function. - - op_full_name: The qualified name of the function. In the form of '::.'. - onnx_function: The onnx-script function from torchlib. - is_custom: Whether the function is a custom function. - is_complex: Whether the function is a function that handles complex valued inputs. - - """ - - onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction - op_full_name: str - is_custom: bool = False - is_complex: bool = False - - -@dataclasses.dataclass(frozen=True, eq=True) -class OpName: - """A class representing an operator name in internal ONNX converter.""" - - namespace: str - op_name: str - overload: str - - @classmethod - def from_name_parts( - cls, namespace: str, op_name: str, overload: str | None = None - ) -> OpName: - # NOTE: in PyTorch, the overload could be unprovided to indicate the - # default overload - if overload is None or overload == "": - overload = "default" - return cls(namespace, op_name, overload) - - @classmethod - def from_qualified_name(cls, qualified_name: str) -> OpName: - """When the name is ::[.]""" - namespace, opname_overload = qualified_name.split("::") - op_name, *overload = opname_overload.split(".", 1) - overload = overload[0] if overload else "default" - return cls(namespace, op_name, overload) - - @classmethod - def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName: - return cls.from_qualified_name(op_overload.name()) - - @classmethod - def from_builtin_function( - cls, builtin_function: types.BuiltinFunctionType - ) -> OpName: - """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName. - - FX graph uses built-in functions to calculate sympy expression. This function - is used to get the OpName from a builtin function. - - Args: - builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc. - - Returns: - OpName: _description_ - """ - op = builtin_function.__name__ # add, sub, etc. - module = builtin_function.__module__ # _operators or math - return cls.from_qualified_name(module + "::" + op) - - def qualified_name(self) -> str: - return f"{self.namespace}::{self.op_name}.{self.overload}" diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py deleted file mode 100644 index 6c414e8d54e7..000000000000 --- a/torch/onnx/_internal/io_adapter.py +++ /dev/null @@ -1,652 +0,0 @@ -# mypy: allow-untyped-defs -from __future__ import annotations - -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import Protocol, runtime_checkable - -import torch -import torch.export as torch_export -from torch.utils import _pytree as pytree - - -if TYPE_CHECKING: - import inspect - from collections.abc import Mapping, Sequence - - -@runtime_checkable -class InputAdaptStep(Protocol): - """A protocol that defines a step in the input adapting process. - - The input adapting process is a sequence of steps that are applied to the - PyTorch model inputs to transform them into the inputs format expected by the - exported ONNX model. Each step takes the PyTorch model inputs as arguments and - returns the transformed inputs. - - This serves as a base formalized construct for the transformation done to model - input signature by any individual component in the exporter. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: ... - - -class InputAdapter: - """A class that adapts the PyTorch model inputs to exported ONNX model inputs format.""" - - def __init__(self, steps: list[InputAdaptStep] | None = None): - self._steps = steps or [] - - def append_step(self, step: InputAdaptStep) -> None: - """Appends a step to the input adapt steps. - - Args: - step: The step to append. - """ - self._steps.append(step) - - def apply( - self, - *model_args, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - **model_kwargs, - ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]: - """Converts the PyTorch model inputs to exported ONNX model inputs format. - - Args: - model_args: The PyTorch model inputs. - model: The PyTorch model. - model_kwargs: The PyTorch model keyword inputs. - Returns: - A sequence of tensors converted from PyTorch model inputs. - """ - args: Sequence[Any] = model_args - kwargs: Mapping[str, Any] = model_kwargs - for step in self._steps: - args, kwargs = step.apply(args, kwargs, model=model) - assert not kwargs - return args - - -@runtime_checkable -class OutputAdaptStep(Protocol): - """A protocol that defines a step in the output adapting process. - - The output adapting process is a sequence of steps that are applied to the - PyTorch model outputs to transform them into the outputs format produced by the - exported ONNX model. Each step takes the PyTorch model outputs as arguments and - returns the transformed outputs. - - This serves as a base formalized construct for the transformation done to model - output signature by any individual component in the exporter. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Any: ... - - -class OutputAdapter: - """A class that adapts the PyTorch model outputs to exported ONNX model outputs format.""" - - def __init__(self, steps: list[OutputAdaptStep] | None = None): - self._steps = steps or [] - - def append_step(self, step: OutputAdaptStep) -> None: - """Appends a step to the output format steps. - - Args: - step: The step to append. - """ - self._steps.append(step) - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[torch.Tensor | int | float | bool | str]: - """Converts the PyTorch model outputs to exported ONNX model outputs format. - - Args: - model_outputs: The PyTorch model outputs. - model: The PyTorch model. - - Returns: - PyTorch model outputs in exported ONNX model outputs format. - """ - for step in self._steps: - model_outputs = step.apply(model_outputs, model=model) - return model_outputs - - -# TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 - - -# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. -class _DummyLeaf: # use a class instead. - pass - - -def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: - def replace_list_with_tuple(x: Any) -> Any: - if type(x) is list: - return pytree.tree_map( - replace_list_with_tuple, - tuple(x), - is_leaf=lambda x: type(x) is list, - ) - return x - - dummy_leaf = _DummyLeaf() - dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) - dummy_tree = pytree.tree_map( - replace_list_with_tuple, - dummy_tree, - is_leaf=lambda x: type(x) is list, - ) - return pytree.tree_structure(dummy_tree) - - -def _open_top_level_sequence_if_single_element( - spec: pytree.TreeSpec, -) -> pytree.TreeSpec: - if spec.type in (tuple, list) and spec.num_children == 1: - return spec.children_specs[0] - return spec - - -def _assert_identical_pytree_spec( - spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str -) -> None: - """Assert the two `TreeSpec` objects are identical. - - Args: - spec1: The first `TreeSpec` object. - spec2: The second `TreeSpec` object. - error_message: The error message to raise if the two `TreeSpec` objects are not - identical. - - Raises: - ValueError: If the two `TreeSpec` objects are not identical. - """ - pass_if_any_checks: Sequence[Callable[[], bool]] = [ - lambda: spec1 == spec2, - # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. - lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), - # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. - lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, - lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), - ] - - if not any(check() for check in pass_if_any_checks): - raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.") - - -class BindInputStep(InputAdaptStep): - """Bind the input arguments to the model signature.""" - - def __init__(self, model_signature: inspect.Signature): - self._model_signature = model_signature - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Bind the input arguments to the model signature. - - We hope the input kwargs will be mapped to bound.args after binding. - If not, we will raise an error. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. args is always empty. - - Raises: - ValueError: If there are keyword-only arguments left after binding args and - kwargs to model signature. - """ - bound = self._model_signature.bind(*model_args, **model_kwargs) - bound.apply_defaults() - - # keyword-only arguments are not handled. - # bound.kwargs only contains keyword-only arguments after calling - # bind & apply_defaults, so we raise if it's not empty. - if bound.kwargs: - raise ValueError("Keyword-only arguments are not supported.") - return (), bound.arguments - - -class MergeKwargsIntoArgsInputStep(InputAdaptStep): - """Merge the input kwargs into the input args.""" - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Merge the input kwargs into the input args. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. kwargs is always empty. - """ - return tuple(model_args) + tuple(model_kwargs.values()), {} - - -class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep): - """Append parameters and buffers to model's positional argument list.""" - - def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None: - self.inputs = inputs - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Append model's parameters and buffers into its input. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args + appended inputs and kwargs. - """ - return (*model_args, *self.inputs), model_kwargs - - -class ConvertComplexToRealRepresentationInputStep(InputAdaptStep): - """Convert complex dtype tensors to real representation tensors. - - ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors - to real representation tensors (i.e., float dtype tensors with an extra dimension - representing the real and imaginary parts of the complex number). - - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Convert complex tensors to float tensors. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - """ - return ( - tuple( - torch.view_as_real(arg.resolve_conj()) - if isinstance(arg, torch.Tensor) and arg.is_complex() - else arg - for arg in model_args - ), - model_kwargs, - ) - - -class RemoveNoneInputStep(InputAdaptStep): - """Remove `None` from arguments. - - This adapt step assumes ``model_kwargs`` is empty. It also assumes ``model_args`` - is flattened, i.e. it does not check `None` inside nested collections. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Remove `None` from arguments. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - - Raises: - ValueError: If `model_kwargs` is not empty. - """ - assert not model_kwargs - return tuple(arg for arg in model_args if arg is not None), {} - - -class RemoveNonTensorInputStep(InputAdaptStep): - """Remove the non-tensor input arguments. - - Dynamo does not support non-tensor input arguments (https://github.com/pytorch/pytorch/issues/99534). - - Specifically, it does put the input into graph with an empty node, but consumed by no ones. - The concrete value is embedded into the graph as a constant arg of a target node. Meta - suggests in this case that one should rewrite the model code to make it tensor if the - input value is supposed to change at runtime. We might need to further investigate - the feasibility of that suggestion. - - For example, - - def func(x, b=1.0): - y = x + b - z = y.relu() - return (y, z) - - x = torch.randn(1, 1, 2, dtype=torch.float32) - gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real") - - # class GraphModule(torch.nn.Module): - # def forward(self, x, b): - # arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec) - # # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b - # add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0); arg0 = None - - # # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu() - # relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor) - # return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec) - - Empty torch.fx.Node input leading to a mismatched number of input with PyTorch, as - it's ignored in ONNX graph. Thus, we delete the useless input here. - - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Remove Constant from arguments. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - - Raises: - ValueError: If `model_kwargs` is not empty. - """ - assert not model_kwargs - return ( - tuple( - arg - for arg in model_args - if not isinstance(arg, (int, float, bool, str)) - ), - {}, - ) - - -class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep): - """Flatten nested collection types and return a flat list of elements. - - ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, - etc). - - This class stores the `SpecTree` output produced when `adapt` was called the first - time. It then validates the `SpecTree` output produced from later `adapt` calls. - """ - - _spec: pytree.TreeSpec | None = None - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Flatten the model args and kwargs and validate the `SpecTree` output. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the flattened model args and kwargs. The kwargs is empty, because - they are flattened and merged into the args. - - Raises: - ValueError: If the `SpecTree` output produced from the current `model_outputs` - is not identical to the `SpecTree` output produced from the first - `model_outputs` that was passed to this method. - """ - flattened_args, spec = pytree.tree_flatten((model_args, model_kwargs)) - if self._spec is None: - self._spec = spec - else: - _assert_identical_pytree_spec( - self._spec, - spec, - error_message="Model inputs incompatible with the format that was exported. ", - ) - return flattened_args, {} - - -class FlattenOutputStep(OutputAdaptStep): - """Flatten nested collection types and return a flat list of elements. - - ONNX can't represent collection types (e.g., dictionary, tuple of tuple of tensor, - etc). - - NOTE: Ideally we would want to use ``FlattenOutputWithTreeSpecValidationOutputStep``, such - that `SpecTree` can be validate for new model outputs. However, this is not possible - currently because we never have access to real PyTorch model outputs during export. - Only traced outputs may be available, but they are not an accurate reflection of the - original PyTorch model outputs format as they are typically in their own unique format, - depending on the tracing strategy. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - A tuple of the flattened model outputs. - """ - return pytree.tree_leaves(model_outputs) - - -class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep): - """Convert complex dtype tensors to real representation tensors. - - ONNX does not support complex dtype tensors. Thus, we convert complex dtype tensors - to real representation tensors (i.e., float dtype tensors with an extra dimension - representing the real and imaginary parts of the complex number). - - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Any: - """Convert float tensors to complex tensors. - - Args: - model_output: The model output. - model: The PyTorch model. - - Returns: - A tuple of the model output. - """ - return [ - torch.view_as_real(output.resolve_conj()) - if isinstance(output, torch.Tensor) and torch.is_complex(output) - else output - for output in model_outputs - ] - - -class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep): - """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation. - - This class stores the `SpecTree` output produced when `adapt` was called the first - time. It then validates the `SpecTree` output produced from later `adapt` calls. - """ - - _spec: pytree.TreeSpec | None = None - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs and validate the `SpecTree` output. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - flattened_outputs: The flattened model outputs. - - Raises: - ValueError: If the `SpecTree` output produced from the current `model_outputs` - is not identical to the `SpecTree` output produced from the first - `model_outputs` that was passed to this method. - """ - flattened_outputs, spec = pytree.tree_flatten(model_outputs) - if self._spec is None: - self._spec = spec - else: - _assert_identical_pytree_spec( - self._spec, - spec, - error_message="Model outputs incompatible with the format that was exported. ", - ) - return flattened_outputs - - -class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep): - """Prepend model parameters, buffers and constants to the user input. - - :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they - must be added to the user input before the model is executed. - - Args: - model: The PyTorch model with embedded parameters and buffers. - """ - - def apply( - self, - model_args: Sequence[Any], - model_kwargs: Mapping[str, Any], - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> tuple[Sequence[Any], Mapping[str, Any]]: - """Convert complex tensors to float tensors. - - Args: - model_args: The model args. - model_kwargs: The model kwargs. - model: The PyTorch model. - - Returns: - A tuple of the model args and kwargs. - """ - ordered_params = tuple( - model.state_dict[name] # type: ignore[union-attr,index] - for name in model.graph_signature.parameters # type: ignore[union-attr] - ) - non_persistent_buffers = set(model.graph_signature.non_persistent_buffers) # type: ignore[arg-type, union-attr] - ordered_buffers = [] - for name in model.graph_signature.buffers: # type: ignore[union-attr] - if name in non_persistent_buffers: - ordered_buffers.append(model.constants[name]) # type: ignore[index, union-attr] - else: - ordered_buffers.append(model.state_dict[name]) # type: ignore[union-attr,index] - ordered_constant_tensors = tuple( - model.constants[fqn] # type: ignore[union-attr,index] - for fqn in model.graph_signature.lifted_tensor_constants # type: ignore[union-attr] - ) - - # NOTE: calling convention is first params, then buffers, then args as user supplied them. - # See: torch/_functorch/aot_autograd.py#L1034 - updated_args = ( - *ordered_params, - *ordered_buffers, - *ordered_constant_tensors, - *model_args, - ) - if model_kwargs: - return MergeKwargsIntoArgsInputStep().apply( - updated_args, model_kwargs, model=model - ) - return updated_args, {} - - -class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): - """Prepend model's mutated buffers to the user output. - - :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they - must be added to the user output after the model is executed. - - Args: - model: The PyTorch model with mutated buffers. - """ - - def apply( - self, - model_outputs: Any, - model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None, - ) -> Sequence[Any]: - """Flatten the model outputs and validate the `SpecTree` output. - - Args: - model_outputs: The model outputs to flatten. - model: The PyTorch model. - - Returns: - flattened_outputs: The flattened model outputs. - """ - - assert isinstance(model, torch_export.ExportedProgram), ( - "'model' must be torch_export.ExportedProgram" - ) - ordered_buffers = tuple( - model.state_dict[name] - if name in model.state_dict - else model.constants[name] - for name in model.graph_signature.buffers_to_mutate.values() - ) - - # NOTE: calling convention is first mutated buffers, then outputs args as model returned them. - updated_outputs = (*ordered_buffers, *model_outputs) - return updated_outputs diff --git a/torch/onnx/_internal/onnxruntime.py b/torch/onnx/_internal/onnxruntime.py deleted file mode 100644 index b994328fcdd8..000000000000 --- a/torch/onnx/_internal/onnxruntime.py +++ /dev/null @@ -1,1260 +0,0 @@ -# mypy: allow-untyped-defs -import dataclasses -import importlib -import logging -import os -from collections.abc import Mapping, Sequence -from typing import Any, Callable, Final, Optional, TYPE_CHECKING, Union -from typing_extensions import TypeAlias - -import torch -import torch._C -import torch._ops -import torch._prims.executor -import torch.fx -import torch.onnx._internal._lazy_import -from torch._subclasses.fake_tensor import FakeTensor -from torch.fx._compatibility import compatibility -from torch.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.fx.passes.operator_support import OperatorSupport -from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.utils import _pytree - - -if TYPE_CHECKING: - import onnx - import onnxruntime - from onnxruntime.capi import _pybind_state as ORTC - - import torch.onnx - import torch.onnx._internal - import torch.onnx._internal._exporter_legacy - import torch.onnx._internal.fx.decomposition_table - import torch.onnx._internal.fx.passes # noqa: TCH004 - - -_SUPPORT_ONNXRT: Optional[bool] = None - -__all__ = [ - "is_onnxrt_backend_supported", - "torch_compile_backend", - "OrtExecutionProvider", - "OrtBackendOptions", - "OrtBackend", -] - - -def is_onnxrt_backend_supported() -> bool: - """Returns ``True`` if ONNX Runtime dependencies are installed and usable - to support TorchDynamo backend integration; ``False`` otherwise. - - Example:: - - # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) - >>> import torch - >>> if torch.onnx.is_onnxrt_backend_supported(): - ... @torch.compile(backend="onnxrt") - ... def f(x): - ... return x * x - ... print(f(torch.randn(10))) - ... else: - ... print("pip install onnx onnxscript onnxruntime") - ... - """ - global _SUPPORT_ONNXRT - - if _SUPPORT_ONNXRT is None: - # `onnxruntime` might import a lot of other runtime packages, - # e.g. apex, deepspeed, transformers. - # So lazy-importing onnxruntime to avoid possible circular import. - try: - importlib.import_module("onnxruntime") - importlib.import_module("onnxruntime.capi._pybind_state") - - # This is not use directly in DORT but needed by underlying exporter, - # so we still need to check if it exists. - importlib.import_module("onnxscript") - - import torch.onnx # noqa: F401 - import torch.onnx._internal # noqa: F401 - import torch.onnx._internal._exporter_legacy # noqa: F401 - from torch.onnx._internal.fx import ( # noqa: F401 - decomposition_table, - fx_onnx_interpreter, - passes, - type_utils, - ) - - _SUPPORT_ONNXRT = True - except ImportError: - _SUPPORT_ONNXRT = False - - return _SUPPORT_ONNXRT - - -_dumped_onnx_model: dict[str, int] = {} - - -def _dump_onnx_model( - model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None -) -> str: - """Stores the onnx model into a file. - The name is "{ONNXRT_DUMP_PATH}{N}.onnx" - where *N* is the number of files already stored with - this prefix. - If graph_module is not None, the graph is stored as a string with - the same filename except the extension (.txt). - """ - prefix = os.environ.get("ONNXRT_DUMP_PATH", None) - if not prefix: - return "" - n = _dumped_onnx_model.get(prefix, -1) + 1 - filename = f"{prefix}{n}.onnx" - with open(filename, "wb") as f: - f.write(model_string) - _dumped_onnx_model[prefix] = n - if graph_module is not None: - filename_txt = f"{prefix}{n}.txt" - with open(filename_txt, "w", encoding="utf-8") as f: - f.write(str(graph_module.graph)) - return filename - - -def _infer_default_eps() -> Sequence[str]: - # TODO: select a good default based on the capabilities of the host - # e.g. DML on Windows, etc. - return ["CPUExecutionProvider"] - - -def _nvtx_range_push(name: str): - """If PyTorch is installed with CUDA support, this starts NVTX range. - - Check torch.cuda.nvtx.range_push's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_push(name) - - -def _nvtx_range_pop(): - """If PyTorch is installed with CUDA support, this terminates NVTX range. - - Check torch.cuda.nvtx.range_pop's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_pop() - - -def _get_ort_device_type(device_type: str): - from onnxruntime.capi import _pybind_state as ORTC - - if device_type == "cuda": - return ORTC.OrtDevice.cuda() - if device_type == "cpu": - return ORTC.OrtDevice.cpu() - # ort pytorch device is mapped to NPU OrtDevice type - if device_type == "maia": - return ORTC.OrtDevice.npu() - raise ValueError("Unsupported device type: " + device_type) - - -logger = logging.getLogger(__name__) -# Uncomment the following lines to print out development info. -# logging.basicConfig(level=logging.WARNING) -# logger.setLevel(logging.WARNING) - - -class OrtOperatorSupport(OperatorSupport): - """Operator support for ONNXRuntime backend. - - It has two-level of support decision. One is via support_dict and the other one - is via extra_support_dict. The logic of using support_dict is implemented in - OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. - """ - - def __init__(self, support_dict: set[Any], extra_support_dict: dict[str, Any]): - # Use extra_support_dict[op_name] = None to indicate - # we support op_name with all input types. Otherwise, - # see support_dict (type: SupportDict) in operator_support.py - # for specifying supported types. - super().__init__(extra_support_dict) - self._onnx_support_dict = support_dict - - def is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: - # OperatorSupport.is_node_supported returns True for non-callable nodes. - # Since ORT can't execute them, we return False here to override the base - # behavior. - if node.op not in CALLABLE_NODE_OPS: - return False - # This is the and the only place to decide if aten op is supported. - if node.op == "call_function" and node.target in self._onnx_support_dict: - logger.info( - "support_dict supports node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return True - # If node.target is not in support_dict, we still want to check if torch.jit.script - # can convert it to ONNX equivalence. Let's use base mechanism to do this. - # See extra_support_dict for supported ops. - if super().is_node_supported(submodules, node): - logger.info( - "extra_support_dict supports node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return True - logger.warning( - "support_dict and extra_support_dict don't support node.target: %s (type: %s)", - node.target, - type(node.target), - ) - return False - - -def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: - """ - In torch.fx.Graph, placeholder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -def _infer_ep_from_device(*args) -> tuple[str, ...]: - """Return the first valid device (i.e., GPU or CPU) in argument list.""" - eps = [] - for arg in args: - if hasattr(arg, "device"): - device = arg.device - if device.type == "cuda": - eps.append("CUDAExecutionProvider") - elif device.type == "cpu": - eps.append("CPUExecutionProvider") - return tuple(eps) - - -def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> tuple[Any, ...]: - placeholders = [] - for node in graph_module.graph.nodes: - if node.op == "placeholder": - if hasattr(node, "meta") and "val" in node.meta: - assert isinstance(node.meta["val"], torch.Tensor) - placeholders.append(node) - return tuple(placeholders) - - -def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: - """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" - for node in graph_module.graph.nodes: - if node.op == "output": - # Output node is unique. Let's retrieve output values from - # this node's input list. And then just return. - return node.args[0] - raise ValueError("No output node found in this torch.fx.GraphModule.") - - -def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> tuple[str, ...]: - """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" - flattened_output_args, _ = _pytree.tree_flatten( - _extract_graph_module_outputs(graph_module) - ) - # Output arguments with example value (type: torch.Tensor) in the `graph_module`. - selected_output_args = [ - output_arg.meta["val"] - for output_arg in flattened_output_args - # output_arg must have tensor for its device information. - # Otherwise, skip it. - if (hasattr(output_arg, "meta") and "val" in output_arg.meta) - ] - return _infer_ep_from_device(*selected_output_args) - - -def _sort_eps(eps: tuple[str, ...]) -> tuple[str, ...]: - """Sort execution providers in eps based on pre-set priority.""" - - def get_execution_provider_priority(ep: str) -> int: - if ep == "CPUExecutionProvider": - # Lowest priority. - return 2 - if ep == "CUDAExecutionProvider": - # Higher priority than CPU but lower than - # other specialized EPs. - return 1 - # Highest priority. - return 0 - - unique_eps = set(eps) - return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) - - -def _get_onnx_devices( - values: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple["ORTC.OrtDevice", ...]: - from onnxruntime.capi import _pybind_state as ORTC - - def _device_id_or_zero(device_id: int) -> int: - return device_id or 0 - - def _map_tensor_or_sym_to_device( - value: Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ) -> int: - if isinstance(value, torch.Tensor): - return ORTC.OrtDevice( - _get_ort_device_type(value.device.type), - ORTC.OrtDevice.default_memory(), - _device_id_or_zero(value.device.index), - ) - elif isinstance( - value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool) - ): - return ORTC.OrtDevice( - _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0 - ) - else: - raise ValueError("Unsupported value type: " + str(type(value))) - - if len(values) > 0: - ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values) - return ort_devices - else: - return (_map_tensor_or_sym_to_device(1),) - - -def _get_ortvalues_from_torch_tensors( - tensors: tuple[torch.Tensor, ...], devices: tuple["ORTC.OrtDevice", ...] -) -> tuple[torch.Tensor, ...]: - # TODO(justinchuby): Refactor this function - import numpy as np - from onnxruntime.capi import _pybind_state as ORTC - - torch_dtype_to_numpy_dtype = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.longlong, - torch.bool: np.bool_, - } - ortvalues = ORTC.OrtValueVector() - ortvalues.reserve(len(tensors)) - dtypes = [] - shapes = [] - data_ptrs = [] - - for tensor in tensors: - dtypes.append(torch_dtype_to_numpy_dtype[tensor.dtype]) - shapes.append(tensor.size()) - data_ptrs.append(tensor.data_ptr()) - ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) - return ortvalues - - -def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: - if tensor.is_sparse: - raise ValueError("sparse tensor is not yet supported.") - out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) - return out - - -def _adjust_scalar_from_fx_to_onnx( - dynamo_value: Union[ - torch.Tensor, - int, - float, - bool, - ], - value_info: "onnx.ValueInfoProto", # type: ignore[name-defined] -) -> torch.Tensor: - """Helper function to wrap PyTorch variables as torch.Tensor""" - if ( - isinstance(dynamo_value, torch.Tensor) - and len(value_info.type.tensor_type.shape.dim) == 0 - and dynamo_value.shape == (1,) - ): - # ONNX expect a scalar with empty shape. - # In contrast, PyTorch usually allows implicit - # conversion between shape=() and shape=(1,). - # - # Below, PyTorch's shape (1,) is reshaped to (). - return torch.squeeze(dynamo_value) - elif isinstance(dynamo_value, int): - return torch.tensor(dynamo_value, dtype=torch.int64) - elif isinstance(dynamo_value, float): - return torch.tensor(dynamo_value, dtype=torch.float32) - elif isinstance(dynamo_value, bool): - return torch.tensor(dynamo_value, dtype=torch.bool) - else: - assert isinstance(dynamo_value, torch.Tensor) - return dynamo_value.contiguous() - - -def _adjust_scalar_from_onnx_to_fx( - tensor: torch.Tensor, - prim_value: Union[ - torch.Tensor, - torch.SymInt, - int, - torch.SymFloat, - float, - torch.SymBool, - bool, - ], -) -> Union[ - torch.Tensor, - int, - float, - bool, -]: - """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" - assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." - if isinstance( - prim_value, - (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool), - ): - # Convert tensor back to scalar to match Dynamo's expectation. - return tensor.item() - return tensor - - -def _run_onnx_session_with_ortvaluevector( - sess: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - inputs: tuple[torch.Tensor, ...], - input_devices: tuple["ORTC.OrtDevice", ...], - output_names: tuple[str, ...], - outputs: tuple[torch.Tensor, ...], - output_devices: tuple["ORTC.OrtDevice", ...], - preallocate_output: bool, - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - normalized_prim_outputs: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple[Union[torch.Tensor, int, float, bool], ...]: - import onnxruntime - from onnxruntime.capi import _pybind_state as ORTC - - _nvtx_range_push("contiguous") - inputs = tuple( - _adjust_scalar_from_fx_to_onnx(arg, value_info) - for arg, value_info in zip(inputs, input_value_infos) - ) - _nvtx_range_pop() - - _nvtx_range_push("push_back_batch") - ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) - - # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. - # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue - # to torch Tensor transferring the ownership. - if preallocate_output: - pth_outputs = tuple( - _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs - ) - ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) - else: - ort_outputs = ORTC.OrtValueVector() - _nvtx_range_pop() - - _nvtx_range_push("run_with_ortvaluevector") - run_options = onnxruntime.RunOptions() - run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") - sess.run_with_ortvaluevector( - run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices - ) - _nvtx_range_pop() - - # Post-processing step: - # wrap ORT's outputs to the schema represented by - # `prim_output` (obtained by running the original - # torch.fx.GraphModule). - if preallocate_output: - # Profile the ORT-to-PyTorch type cast below - _nvtx_range_push("after run_with_ortvaluevector") - # Outputs are stored on pre-allocated torch.Tensors' memory, - # so this case doesn't need to convert ORTValue to torch.Tensor. - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] - for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) - ) - _nvtx_range_pop() - return pth_outputs - else: - import onnxruntime.training - - # Profile the two ORT-to-PyTorch type casts below - _nvtx_range_push("after run_with_ortvaluevector") - # Map ORTValue to torch.Tensor. - pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor( - ort_outputs - ) - # Change some torch.Tensor to int, float, bool. - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] - for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) - ) - _nvtx_range_pop() - return pth_outputs - - -def _run_onnx_session_with_fetch( - sess: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - inputs: tuple[torch.Tensor, ...], - input_devices: tuple["ORTC.OrtDevice", ...], - output_names: tuple[str, ...], - outputs: tuple[torch.Tensor, ...], - output_devices: tuple["ORTC.OrtDevice", ...], - preallocate_output: bool, - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - normalized_prim_outputs: tuple[ - Union[ - torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool - ], - ..., - ], -) -> tuple[Union[torch.Tensor, int, float, bool], ...]: - import onnxruntime - - inputs = tuple( - _adjust_scalar_from_fx_to_onnx(arg, value_info) - for arg, value_info in zip(inputs, input_value_infos) - ) - feed = { - name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy()) - for name, tensor in zip(input_names, inputs) - } - ort_outputs = sess.run(output_names, feed) - pth_outputs = tuple( - _adjust_scalar_from_onnx_to_fx( - torch.from_numpy(value), - prim_output, - ) - for value, prim_output in zip(ort_outputs, normalized_prim_outputs) - ) - return pth_outputs - - -def _from_python_type_to_onnx_tensor_element_type(type: type): - """ - Converts a Python type to the corresponding ONNX tensor element type. - For example, `_from_python_type_to_onnx_tensor_element_type(float)` returns - `onnx.TensorProto.FLOAT`. - - Args: - type (type): The Python type to convert. - - Returns: - int: The corresponding ONNX tensor element type. - - """ - import onnx - - _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE = { - float: onnx.TensorProto.FLOAT, # type: ignore[attr-defined] - int: onnx.TensorProto.INT64, # type: ignore[attr-defined] - bool: onnx.TensorProto.BOOL, # type: ignore[attr-defined] - } - return _PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPE.get(type) - - -class OrtExecutionInfoPerSession: - """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" - - def __init__( - self, - session: "onnxruntime.InferenceSession", - input_names: tuple[str, ...], - input_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - output_names: tuple[str, ...], - output_value_infos: tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] - input_devices: tuple["ORTC.OrtDevice", ...], - output_devices: tuple["ORTC.OrtDevice", ...], - example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor], - ): - # Carrier of ONNX model and its executor. - self.session: onnxruntime.InferenceSession = session - # For the ONNX model stored in self.session, self.input_names[i] is the - # name of the i-th positional input. - self.input_names: tuple[str, ...] = input_names - # self.input_name[i]'s type information is stored in self.input_value_infos[i]. - self.input_value_infos: tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined] - # Similar to self.input_names, but for outputs. - self.output_names: tuple[str, ...] = output_names - # Similar to self.input_value_infos but for outputs. - self.output_value_infos: tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined] - # For the ONNX model stored in self.session, self.input_devices[i] is the - # i-th positional input's device. - self.input_devices: tuple[ORTC.OrtDevice, ...] = input_devices - # Similar to self.input_devices, but for outputs. - self.output_devices: tuple[ORTC.OrtDevice, ...] = output_devices - # This is the outputs of executing the original torch.fx.GraphModule with example inputs - # (i.e., args passed into OrtBackend._ort_acclerated_call). - self.example_outputs: Union[tuple[torch.Tensor, ...], torch.Tensor] = ( - example_outputs - ) - - def is_supported(self, *args): - # TODO(justinchuby): Simplify - import onnx - - _onnx_tensor_element_type_to_torch_dtype = { - onnx.TensorProto.FLOAT: torch.float32, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT16: torch.float16, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E5M2: torch.float8_e5m2, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E4M3FN: torch.float8_e4m3fn, # type: ignore[attr-defined] - onnx.TensorProto.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, # type: ignore[attr-defined] - onnx.TensorProto.DOUBLE: torch.float64, # type: ignore[attr-defined] - onnx.TensorProto.BOOL: torch.bool, # type: ignore[attr-defined] - onnx.TensorProto.UINT8: torch.uint8, # type: ignore[attr-defined] - onnx.TensorProto.INT8: torch.int8, # type: ignore[attr-defined] - onnx.TensorProto.INT16: torch.int16, # type: ignore[attr-defined] - onnx.TensorProto.INT32: torch.int32, # type: ignore[attr-defined] - onnx.TensorProto.INT64: torch.int64, # type: ignore[attr-defined] - } - _torch_dtype_to_onnx_tensor_element_type = { - value: key - for key, value in _onnx_tensor_element_type_to_torch_dtype.items() - } - - # Compare the args and the input schema in ONNX model and - # return the first match. - if len(args) != len(self.input_value_infos): - return False - for arg, value_info in zip(args, self.input_value_infos): - if not isinstance(arg, (torch.Tensor, float, int)): - return False - - # Check Python scalars such as int, float, and bool. - if isinstance(arg, (int, float, bool)): - # Map, e.g., float to onnx.TensorProto.FLOAT. - onnx_dtype = _from_python_type_to_onnx_tensor_element_type(type(arg)) - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - if len(value_info.type.tensor_type.shape.dim) != 0: - return False - continue - - # Check tensor. - onnx_dtype = _torch_dtype_to_onnx_tensor_element_type[arg.dtype] - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): - if isinstance(dim, int) and ( - onnx_dim.dim_value == dim or onnx_dim.dim_param - ): - continue - elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: - continue - else: - return False - return True - - -@dataclasses.dataclass -class OrtExecutionInfoForAllGraphModules: - def __init__(self) -> None: - # All sessions (and their related information) created by exporting the same GraphModule - # with different inputs. - self.execution_info_per_graph_module: dict[ - torch.fx.GraphModule, list[OrtExecutionInfoPerSession] - ] = {} - - def search_reusable_session_execution_info( - self, graph_module: torch.fx.GraphModule, *args - ): - if graph_module not in self.execution_info_per_graph_module: - return None - # All execution information for ONNX models exported from the same `graph_module` - # with different inputs. - candidates = self.execution_info_per_graph_module[graph_module] - - for candidate in candidates: - if candidate.is_supported(*args): - # Returns the first session that accepts this input schema. - return candidate - # No reusable session found. - return None - - def cache_session_execution_info( - self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession - ): - if graph_module not in self.execution_info_per_graph_module: - self.execution_info_per_graph_module[graph_module] = [info] - else: - self.execution_info_per_graph_module[graph_module].append(info) - - -OrtExecutionProvider: TypeAlias = Union[str, tuple[str, Mapping[str, Any]]] -"""Either the name of an ONNX Runtime execution provider as a string or -a 2-tuple of the name and a dictionary of execution provider options. - -Examples:: - - >>> "CPUExecutionProvider" - - >>> ("CUDAExecutionProvider", {"device_id": 3}) - -""" - - -@dataclasses.dataclass(frozen=True) -@compatibility(is_backward_compatible=False) -class OrtBackendOptions: - """Options for constructing an ``OrtBackend``, the ONNX Runtime - backend (``"onnxrt"``) for ``torch.compile``. - - Example:: - - >>> @torch.compile( - ... backend="onnxrt", - ... options=torch.onnx._OrtBackendOptions(...), - ... ) - ... def ort_function(x): - ... return x ** x - """ - - preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None - """An optional sequence of execution providers to be prioritized ahead of any - execution providers that may be inferred (see ``infer_execution_providers``). - """ - - infer_execution_providers: bool = True - """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph.""" - - default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None - """The default fallback execution providers. If not specified, one will be - be selected based on the host environment (most likely ``"CPUExecutionProvider"``). - """ - - # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession - # in order to avoid internal allocation of output buffers in InferenceSession. - # If output ortvalue returned from InferenceSession is allocated internally, - # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. - # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor - # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. - # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, - # and use the preallocated output buffers for InferenceSession not holding any ownership for them. - # TODO(wschin): Make it to inference session level flag. - # See https://github.com/pytorch/pytorch/issues/106869. - preallocate_output: bool = False - """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side.""" - - use_aot_autograd: bool = True - """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend - to support training (i.e., backward graphs are also sent to ``OrtBackend``). - - Symbolic execution is used to capture the forward pass and backward passes as a single graph. - Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used - to split the entire graph into forward sub-graph and backward sub-graph. Finally, both - sub-graphs are compiled by ``OrtBackend``. - """ - - ort_session_options: Optional["onnxruntime.SessionOptions"] = None - """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``.""" - - pre_ort_model_transforms: Optional[ # type: ignore[name-defined] - Sequence[Callable[["onnx.ModelProto"], None]] - ] = None - """A list of graph transforms to be applied to the ONNX model before it - is fed to ONNXRuntime's InferenceSession.""" - - -@compatibility(is_backward_compatible=False) -class OrtBackend: - """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. - - The compiler entry point is OrtBackend.compile, which - 1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported - sub-graphs. - 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. - 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. - """ - - def __init__(self, options: Optional[OrtBackendOptions] = None): - from onnxruntime.capi import _pybind_state as ORTC - - import torch.onnx - import torch.onnx._internal._exporter_legacy - import torch.onnx._internal.fx.decomposition_table - - self._options: Final = OrtBackendOptions() if options is None else options - - # options.export_options contains information shared between exporter and DORT. - # For example, they should use the same decomposition table when - # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) - # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model - # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). - # - # Convert user-facing option to internal option used by ONNX exporter - # to access required information. - # Some useful fields: - # - Decomposition table for decomposing FX operators in exporter is - # self._resolved_onnx_exporter_options.decomposition_table. - # - self._resolved_onnx_exporter_options.onnx_registry records what - # aten/prim ops are supported by exporter and their exporters (type: callable). - self._resolved_onnx_exporter_options = ( - torch.onnx._internal._exporter_legacy.ResolvedExportOptions() - ) - - # Given DORT's computation flow: - # 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators - # and send them to DORT. - # 2. Then, DORT exports the selected sub-graphs into ONNX. - # 3. Finally DORT calls ORT to do the computation. - # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...) - # must use the same support_dict. If the support_dict here contains something not - # supported by exporter, exporter will fails in step 2 since the selected graphs may - # contains unsupported operators such as aten::_who_you_are. - # This restriction is automatically done since DORT and exporter shares the same - # self._resolved_onnx_exporter_options. - support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( - self._resolved_onnx_exporter_options.onnx_registry - ) - - extra_support_dict: dict[str, Any] = { - "getattr": None, - # To send operator.getitem to ORT, add the corresponding string - # recognized by PyTorch's OperatorSupport class. - "_operator.getitem": None, - # To send operator.mul to ORT, add the corresponding string - # recognized by PyTorch's OperatorSupport class. - "_operator.mul": None, - "_operator.add": None, - "_operator.sub": None, - } - - self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) - # TODO(wschin): this is a naive implementation of cache without proper guard - # See https://github.com/pytorch/pytorch/issues/106868. - self._partitioner_cache: dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} - # Conceptually, this filed is a 2-layer dictionary - # GraphModule 0 - # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 1 - # ... - # GraphModule 1 - # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 3 - # ... - # ... - # , which caches all previous compilation result so that we can reuse them. - # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs - # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different - # graphs captured by Dynamo and sent to OrtBackend.compile. - self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() - - self._assert_allclose_to_baseline = False - - self.execution_count = 0 - - # Function which invokes ORT do to the real computation. - self.run = ( - _run_onnx_session_with_ortvaluevector - if hasattr(ORTC.OrtValueVector, "push_back_batch") - else _run_onnx_session_with_fetch - ) - - def _select_eps( - self, graph_module: torch.fx.GraphModule, *args - ) -> Sequence[tuple[str, Mapping[str, Any]]]: - inferred_eps: tuple[str, ...] = () - if self._options.infer_execution_providers: - if eps_from_args := _infer_ep_from_device(*args): - # If user feeds CUDA tensor as input argument, - # we want to use CUDA EP. - # Thus, `eps_from_args` (deduced from input arguments) - # has highest priority. - inferred_eps = eps_from_args - elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module): - # If there is no EP in input arguments, we deduce EP from - # graph_module's outputs. Those outputs may come from - # FakeTensorProp or Dynamo's built-in symbolic shape inference. - inferred_eps = eps_from_graph_module - - selected_eps = [] - - for ep in ( - *(self._options.preferred_execution_providers or []), - *_sort_eps(inferred_eps), - *(self._options.default_execution_providers or _infer_default_eps()), - ): - if isinstance(ep, str): - ep = (ep, {}) - elif isinstance(ep, tuple) and ep[1] is None: - ep = (ep[0], {}) - if ep is not None and ep not in selected_eps: - selected_eps.append(ep) - - return selected_eps - - def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): - """This function replaces GraphModule._wrapped_call in compiled model. - - The _wrapped_call is the underlying implementation of forward method. Replacing - it means we delegate the computation to _ort_acclerated_call and therefore - onnxruntime.InferenceSession. - """ - import onnxruntime - - from torch.onnx._internal.fx import fx_onnx_interpreter, passes - - cached_execution_info_per_session = ( - self._all_ort_execution_info.search_reusable_session_execution_info( - graph_module, *args - ) - ) - if cached_execution_info_per_session: - onnx_session = cached_execution_info_per_session.session - input_names = cached_execution_info_per_session.input_names - output_names = cached_execution_info_per_session.output_names - input_value_infos = cached_execution_info_per_session.input_value_infos - output_value_infos = cached_execution_info_per_session.output_value_infos - input_devices = cached_execution_info_per_session.input_devices - output_devices = cached_execution_info_per_session.output_devices - prim_outputs = cached_execution_info_per_session.example_outputs - else: - # It's first time seeing such as graph. Let's make a new session - # (type: onnxruntime.InferenceSession) for it. - - graph_module = passes.MovePlaceholderToFront( - graph_module, - ).run() - # Generate reference outputs. They are used to indicate output - # tensors' types and devices when calling ORT. - # - # WARNING: The downstream code should not change prim_outputs and - # this backend should always produces output with schema identical to prim_outputs'. - - if self._resolved_onnx_exporter_options.dynamic_shapes: - # No pre-allocation when dynamic shape is enabled. - self.preallocate_output = False - extracted_outputs = _extract_graph_module_outputs(graph_module) - - def maybe_map_to_meta_val(value): - if hasattr(value, "meta") and "val" in value.meta: - # Select outputs with "val" information. Without "val", - # it's not possible access output_arg.meta["val"].device. - return value.meta["val"] - else: - return value - - prim_outputs = _pytree.tree_map( - maybe_map_to_meta_val, extracted_outputs - ) - else: - try: - prim_outputs = FakeTensorProp(graph_module).propagate( - *args, **kwargs - ) - except Exception: - logger.warning("FakeTensorProb failed for %s", graph_module) - # When FakeTensorProp fails, it is not possible to preallocate output buffers - # because the output shapes are not inferred. - self.preallocate_output = False - - # rethrow FakeTensorProb failure because it is not yet currently handled. - raise - - # Create the object to iterate through the nodes in graph one-by-one - # and calls the corresponding ONNX exporter for each node. - fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter() - # Cast FX variables if they will result schema-mismatch when searching - # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch, - # but ONNX expects add(double_tensor, double_tensor). - graph_module = passes.InsertTypePromotion(graph_module).run() - # Start the per-node exporting process. It's conceptually a for loop - # scanning through the nodes in the graph. - exported = fx_interpreter.run( - fx_graph_module=graph_module, - onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, - ) - # Convert the exported result to ONNX ModelProto. - onnx_model = exported.to_model_proto( - opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version, - ) - - # Modify ONNX model using pre-registered graph transforms. - # They are in-place modifications for avoiding unnecessary - # copy of ONNX initializers. - if self._options.pre_ort_model_transforms: - for transform in self._options.pre_ort_model_transforms: - transform(onnx_model) - - onnx_model_bytes = onnx_model.SerializeToString() - if os.environ.get("ONNXRT_DUMP_PATH", None): - # If not empty, environment variable ONNXRT_DUMP_PATH defined the path - # where generated onnx files should be stored. - # This module keeps a global variables keeping track of the - # stored models. - # If ONNXRT_DUMP_PATH="dumped/dumped_model_" - # The first file name will be 'dumped/dumped_model_0.onnx'. - # For every dumped model, a text file 'dumped/dumped_model_0.txt' - # is created as well to contain the string representing the graph_module. - _dump_onnx_model(onnx_model_bytes, graph_module=graph_module) - - # Initialize a ORT session to execute this ONNX model. - # Note that TorchDynamo assumes all inputs/outputs are on the - # same device, but it's subject to change (very likely with - # dynamic shape support), so we add execution providers - # based on the logic in _select_eps: (explicitly preferred EPs, - # EPs inferred from inputs or graph, and the fallback default EP)/ - # - # TODO(wschin): enable external allocators. - # See https://github.com/pytorch/pytorch/issues/106867 - onnx_session = onnxruntime.InferenceSession( - path_or_bytes=onnx_model_bytes, - sess_options=self._options.ort_session_options, - providers=self._select_eps(graph_module, *args), - ) - - # Cache ORT session. It's reused for the same "graph_module". - # Generate ONNX model and extract its input and output names. - input_names = tuple(input.name for input in onnx_model.graph.input) - output_names = tuple(output.name for output in onnx_model.graph.output) - input_devices = _get_onnx_devices(args) - # Cache devices for inputs and outputs. They are used to invoke - # ORT session. Output devices indicate where (e.g., GPU or CPU) - # to store outputs - if isinstance(prim_outputs, tuple): - output_devices = _get_onnx_devices(prim_outputs) - else: - output_devices = _get_onnx_devices((prim_outputs,)) - - input_value_infos = tuple(input for input in onnx_model.graph.input) - output_value_infos = tuple(output for output in onnx_model.graph.output) - - execution_info_per_session = OrtExecutionInfoPerSession( - session=onnx_session, - input_names=input_names, - input_value_infos=input_value_infos, - output_names=output_names, - output_value_infos=output_value_infos, - input_devices=input_devices, - output_devices=output_devices, - example_outputs=prim_outputs, - ) - - self._all_ort_execution_info.cache_session_execution_info( - graph_module, execution_info_per_session - ) - - self.execution_count += 1 - - # ORT always returns a tuple of outputs. If the original output is a tensor, - # ORT output's first element must be extracted and returned. Otherwise, type - # mismatch may happen in downstream computation. - is_single_tensor_output = isinstance(prim_outputs, torch.Tensor) - normalized_prim_outputs = ( - (prim_outputs,) if is_single_tensor_output else prim_outputs - ) - assert isinstance(normalized_prim_outputs, tuple) - assert all( - isinstance(elem, (torch.Tensor, torch.SymInt, int)) - for elem in normalized_prim_outputs - ) - - _nvtx_range_push("run_onnx_session_with_ortvaluevector") - onnx_outputs = self.run( - onnx_session, - input_names, - args, - input_devices, - output_names, - normalized_prim_outputs, - output_devices, - self._options.preallocate_output, - input_value_infos, - normalized_prim_outputs, - ) - _nvtx_range_pop() - - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute( - graph_module, *args, executor="aten" - ) - normalized_baseline_ouptuts = ( - (baseline_outputs,) if is_single_tensor_output else baseline_outputs - ) - # Ensure every output tensor is close to the corresponding baseline. - for onnx_output, baseline_output in zip( - onnx_outputs, normalized_baseline_ouptuts - ): - torch.testing.assert_close(onnx_output, baseline_output) - return onnx_outputs[0] if is_single_tensor_output else onnx_outputs - - def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - # Deferred import since CapabilityBasedPartitioner is not decorated with - # @compatibility; importing it at the module level will result in the test - # failing: pytest test/test_fx.py -k test_public_api_surface - # because this module is imported into torch.onnx. - from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner - - # FX graph based partitioning based on ONNX supported ops. - # Given a graph module - # GraphModule0 - # node_0 - # node_1 - # node_2 - # node_3 - # node_4 - # If only node_2 is not supported by ONNX, this graph module will be partitioned into - # GraphModule0 - # GraphModule1 - # node_0 - # node_1 - # node_2 - # GraphModule2 - # node_3 - # node_4 - # by calling CapabilityBasedPartitioner.partition_and_fuse. - # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call) - # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT. - if graph_module in self._partitioner_cache: - partitioned_prim_graph_module = self._partitioner_cache[graph_module] - else: - prim_graph_module = graph_module - partitioner = CapabilityBasedPartitioner( - prim_graph_module, - self._supported_ops, - allows_single_node_partition=True, - ) - partitioned_prim_graph_module = partitioner.partition_and_fuse() - self._partitioner_cache[graph_module] = partitioned_prim_graph_module - - # Overriding fused_module's __call__() function with ort_acclerated_call() - # This loop goes through all graph partitions (each of them is an ONNX-representable graph) - # and override their _wrapped_call function with _ort_accelerated_call. - # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. - for node in partitioned_prim_graph_module.graph.nodes: - # TODO(wschin): use a better way to identify fused submodule - # See https://github.com/pytorch/pytorch/issues/106872. - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(partitioned_prim_graph_module, node.name) - # self.ort_acclerated_call is responsible for exporting graph to ONNX, - # creating ORT session, and running ORT session. - fused_module._wrapped_call = self._ort_acclerated_call - - return partitioned_prim_graph_module - - def __call__( - self, graph_module: torch.fx.GraphModule, args - ) -> torch.fx.GraphModule: - """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler - will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise, - the ``compile`` method is invoked directly.""" - if self._options.use_aot_autograd: - from functorch.compile import min_cut_rematerialization_partition - from torch._dynamo.backends.common import aot_autograd - - return aot_autograd( - fw_compiler=self.compile, - partition_fn=min_cut_rematerialization_partition, - decompositions=self._resolved_onnx_exporter_options.decomposition_table, - )(graph_module, args) - - return self.compile(graph_module, args) - - __instance_cache_max_count: Final = 8 - __instance_cache: Final[list["OrtBackend"]] = [] - - @staticmethod - def get_cached_instance_for_options( - options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, - ) -> "OrtBackend": - """Returns a possibly cached instance of an ``OrtBackend``. If an existing - backend was created previously through this function with the same options, - it will be returned. Otherwise a new backend will be created, cached, and - returned. - - Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend`` - will always be returned, since ``onnxruntime.SessionOptions`` cannot - participate in caching.""" - - def reusable(a: OrtBackendOptions, b: OrtBackendOptions): - if ( - a.preferred_execution_providers != b.preferred_execution_providers - or a.infer_execution_providers != b.infer_execution_providers - or a.default_execution_providers != b.default_execution_providers - or a.preallocate_output != b.preallocate_output - or a.use_aot_autograd != b.use_aot_autograd - or a.pre_ort_model_transforms != b.pre_ort_model_transforms - ): - return False - - # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled, - # and holds too much potential state to reasonably check manually; - # ort_session_options is provided at all, the backend does not participate - # in caching. - if a.ort_session_options is not None or b.ort_session_options is not None: - return False - - return True - - if not isinstance(options, OrtBackendOptions): - options = OrtBackendOptions(**(options or {})) - - backend = next( - (b for b in OrtBackend.__instance_cache if reusable(b._options, options)), - None, - ) - - if backend is None: - assert ( - len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count - ), ( - f"No more than {OrtBackend.__instance_cache_max_count} instances of " - f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly " - "to pass to `torch.compile`. " - "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 " - "for discussion." - ) - OrtBackend.__instance_cache.append(backend := OrtBackend(options)) - - return backend - - @staticmethod - def clear_cached_instances(): - OrtBackend.__instance_cache.clear() - - @staticmethod - def get_cached_instances(): - return tuple(OrtBackend.__instance_cache) - - -@compatibility(is_backward_compatible=False) -def torch_compile_backend( - graph_module: torch.fx.GraphModule, - args, - *, - options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, -): - return OrtBackend.get_cached_instance_for_options(options)(graph_module, args) diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 469d7a80f77d..0b8e2478ce33 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -513,7 +513,9 @@ def to_slice_input(list_or_value, default_value=None): if is_none_value(list_or_value) and default_value is not None: list_or_value = [default_value] - if isinstance(list_or_value, (list, torch.Tensor)): + if isinstance(list_or_value, torch.Tensor): + return g.op("Constant", value_t=list_or_value.clone().detach()) + elif isinstance(list_or_value, list): return g.op("Constant", value_t=torch.tensor(list_or_value)) rank = symbolic_helper._get_tensor_rank(list_or_value) diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 65e76634421a..00b3c9c28774 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -117,6 +117,7 @@ def __setstate__(self, state): ) def share_memory(self): + """Calls tensor.share_memory_() on the state sum tensors.""" for group in self.param_groups: for p in group["params"]: state = self.state[p] diff --git a/torch/overrides.py b/torch/overrides.py index f29ffe57e36a..046171ef6c5c 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -820,6 +820,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, torch.native_dropout: lambda input, p, train: -1, torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, + torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1, torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, torch.native_channel_shuffle: lambda input, groups: -1, @@ -1511,7 +1512,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.view: lambda self, shape: -1, Tensor.view_as: lambda self, other: -1, Tensor.zero_: lambda self: -1, - Tensor.__dlpack__: lambda self, stream=None, max_version=None: -1, + Tensor.__dlpack__: lambda self, stream=None, max_version=None, dl_device=None, copy=None: -1, Tensor.__dlpack_device__: lambda self: -1, torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, } # fmt: skip diff --git a/torch/serialization.py b/torch/serialization.py index 9660b4ec3cbc..61a4acf68415 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1988,7 +1988,7 @@ def _get_offset(key, name, numel): # for a given key. offsets[name] = storage_offset - # Increment current_offset of offset where next zipfile header starts + # Increment current_offset to offset where next zipfile header starts current_offset = storage_offset + numel # add size of data descriptor after payload if numel > 0: @@ -2004,7 +2004,10 @@ def load_tensor(dtype, numel, key, location): if torch._guards.detect_fake_mode(None) is not None: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes, device="meta") - storage._checkpoint_offset = zip_file.get_record_offset(name) + if can_calculate_storage_offsets: + storage._checkpoint_offset = _get_offset(key, name, numel) + else: + storage._checkpoint_offset = zip_file.get_record_offset(name) elif _serialization_tls.skip_data: nbytes = numel * torch._utils._element_size(dtype) storage = torch.UntypedStorage(nbytes) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 92ae95bef8d0..85a333a56601 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2456,7 +2456,7 @@ def error_inputs_cat(op_info, device, **kwargs): # error inputs for empty tensors yield ErrorInput(SampleInput([], kwargs={'dim': 1}), - error_regex='non-empty list of Tensors') + error_regex='non-empty list of Tensors', error_type=ValueError) # error inputs for different sizes yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}), diff --git a/torch/testing/_internal/common_mkldnn.py b/torch/testing/_internal/common_mkldnn.py index ffaed6c7e009..44da60a5ad1f 100644 --- a/torch/testing/_internal/common_mkldnn.py +++ b/torch/testing/_internal/common_mkldnn.py @@ -7,9 +7,6 @@ import torch -# Test whether hardware BF32 math mode enabled. It is enabled only on: -# - MKLDNN is available -# - BF16 is supported by MKLDNN def bf32_is_not_fp32(): if not torch.backends.mkldnn.is_available(): return False @@ -18,8 +15,16 @@ def bf32_is_not_fp32(): return True +def tf32_is_not_fp32(): + if not torch.backends.mkldnn.is_available(): + return False + if not torch._C._cpu._is_amx_fp16_supported(): + return False + return True + + @contextlib.contextmanager -def bf32_off(): +def reduced_f32_off(): old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision old_conv_precision = torch.backends.mkldnn.conv.fp32_precision try: @@ -47,19 +52,39 @@ def bf32_on(self, bf32_precision=1e-2): self.precision = old_precision -# This is a wrapper that wraps a test to run this test twice, one with -# allow_bf32=True, another with allow_bf32=False. When running with -# allow_bf32=True, it will use reduced precision as specified by the -# argument -def bf32_on_and_off(bf32_precision=1e-2): - def with_bf32_disabled(self, function_call): - with bf32_off(): +@contextlib.contextmanager +def tf32_on(self, tf32_precision=1e-5): + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision + old_precision = self.precision + try: + torch.backends.mkldnn.matmul.fp32_precision = "tf32" + torch.backends.mkldnn.conv.fp32_precision = "tf32" + self.precision = tf32_precision + yield + finally: + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision + self.precision = old_precision + + +# This is a wrapper that wraps a test to run this test three times, one with +# reduced_f32 OFF, the others with reduced_f32 ON (including bf32 ON and tf32 +# ON). When running with reduced_f32 ON, it will use reduced precision (bf16/ +# tf32) as specified by the argument. +def reduced_f32_on_and_off(bf32_precision=1e-2, tf32_precision=1e-5): + def with_reduced_f32_disabled(self, function_call): + with reduced_f32_off(): function_call() def with_bf32_enabled(self, function_call): with bf32_on(self, bf32_precision): function_call() + def with_tf32_enabled(self, function_call): + with tf32_on(self, tf32_precision): + function_call() + def wrapper(f): params = inspect.signature(f).parameters arg_names = tuple(params.keys()) @@ -67,14 +92,19 @@ def wrapper(f): @functools.wraps(f) def wrapped(*args, **kwargs): kwargs.update(zip(arg_names, args)) - cond = bf32_is_not_fp32() + cond = True if "device" in kwargs: cond = cond and (torch.device(kwargs["device"]).type == "cpu") if "dtype" in kwargs: cond = cond and (kwargs["dtype"] == torch.float) - if cond: - with_bf32_disabled(kwargs["self"], lambda: f(**kwargs)) - with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) + bf32_cond = cond and bf32_is_not_fp32() + tf32_cond = cond and tf32_is_not_fp32() + if bf32_cond or tf32_cond: + with_reduced_f32_disabled(kwargs["self"], lambda: f(**kwargs)) + if bf32_cond: + with_bf32_enabled(kwargs["self"], lambda: f(**kwargs)) + if tf32_cond: + with_tf32_enabled(kwargs["self"], lambda: f(**kwargs)) else: f(**kwargs) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 692b71660071..1b4f03da3dfc 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3376,6 +3376,8 @@ def wrapper(*args, **kwargs): if strict_mode or should_reset_dynamo: torch._dynamo.reset() + elif torch._dynamo.config.compiled_autograd: + torch._dynamo.compiled_autograd.reset() # Early terminate test if necessary. If using pytest, use the -x flag instead if using_unittest and self._should_stop_test_suite(): diff --git a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py index 838c5fd01adf..60c744ac1a84 100644 --- a/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py @@ -22,7 +22,7 @@ def world_size(self): return TEST_GPU_NUM def init_pg(self, backend="nccl"): - if backend not in ["nccl", "gloo", "mpi"]: + if backend not in ["nccl", "gloo", "mpi", "hccl"]: raise RuntimeError(f"Backend {backend} not supported!") dist.init_process_group( diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 8e9a9a55f677..94bfead8a0c0 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -17,6 +17,8 @@ from torch.distributed.tensor import ( DeviceMesh, distribute_tensor, + DTensor, + init_device_mesh, Placement, Replicate, Shard, @@ -351,7 +353,7 @@ def backend(self) -> str: return backend def build_device_mesh(self) -> DeviceMesh: - return DeviceMesh(self.device_type, list(range(self.world_size))) + return init_device_mesh(self.device_type, (self.world_size,)) def init_pg(self, eager_init) -> None: if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: @@ -403,6 +405,32 @@ def setUp(self) -> None: super().setUp() self._spawn_processes() + def _test_op_on_dtensor(self, op_call, *args, **kwargs) -> None: + """ + This function checks ``op_call(dtensor).full_tensor() == op_call(dtensor.full_tensor())``. + Unlike _test_op where the DTensor sharding is generated by DTensorConverter, + this function takes in DTensor object directly as argument and test the equality + of calling op on full_tensor() and DTensor. + """ + # call full_tensor() on DTensor args/kwargs + args_flattened, args_spec = tree_flatten(args) + full_tensor_args_flattened = tuple( + arg.full_tensor().detach().clone() if isinstance(arg, DTensor) else arg + for arg in args_flattened + ) + full_tensor_args = tree_unflatten(full_tensor_args_flattened, args_spec) + full_tensor_kwargs = { + k: v.full_tensor() if isinstance(v, DTensor) else v + for k, v in kwargs.items() + } + + out_flattened, _ = tree_flatten( + op_call(*full_tensor_args, **full_tensor_kwargs) + ) + d_out_flattened, _ = tree_flatten(op_call(*args, **kwargs)) + d_out_full_tensor_flattened = [dt.full_tensor() for dt in d_out_flattened] + self.assertEqual(out_flattened, d_out_full_tensor_flattened) + # pyre-ignore[2]: def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: out = op_call(*args, **kwargs) @@ -456,7 +484,7 @@ def device_type(self) -> str: return DEVICE_TYPE def build_device_mesh(self): - return DeviceMesh(self.device_type, list(range(self.world_size))) + return init_device_mesh(self.device_type, (self.world_size,)) def setUp(self) -> None: super().setUp() diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 91a4aaa5728a..8a521d56f5f8 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -16,6 +16,7 @@ from torch._inductor.codecache import CppCodeCache from torch._inductor.custom_graph_pass import CustomGraphModulePass from torch._inductor.codegen.common import ( + get_custom_backend_config_for_device, get_custom_backend_pass_for_device, get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -27,6 +28,7 @@ from torch._inductor.utils import GPU_TYPES, get_gpu_type, is_gpu from torch.utils._helion import has_helion from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -308,7 +310,8 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): def patch_inductor_backend( device: str, python_wrapper_codegen: PythonWrapperCodegen = None, - custom_pass: CustomGraphModulePass = None + custom_pass: CustomGraphModulePass = None, + custom_backend_config: ConfigModule = None ): """ Patch the inductor backend for a specific device. @@ -321,6 +324,7 @@ def patch_inductor_backend( original_python_wrapper = get_wrapper_codegen_for_device(device, False) original_cpp_wrapper = get_wrapper_codegen_for_device(device, True) original_custom_pass = get_custom_backend_pass_for_device(device) + original_custom_backend_config = get_custom_backend_config_for_device(device) try: # Register modified backend for the device @@ -329,7 +333,8 @@ def patch_inductor_backend( original_scheduling, python_wrapper_codegen if python_wrapper_codegen is not None else original_python_wrapper, original_cpp_wrapper, - custom_pass if custom_pass is not None else original_custom_pass + custom_pass if custom_pass is not None else original_custom_pass, + custom_backend_config if custom_backend_config is not None else original_custom_backend_config ) yield finally: @@ -339,5 +344,6 @@ def patch_inductor_backend( original_scheduling, original_python_wrapper, original_cpp_wrapper, - original_custom_pass + original_custom_pass, + original_custom_backend_config ) diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 3fab41d82bc4..664994e6fe38 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -68,6 +68,12 @@ class TorchDispatchMode: API self-referential (beware of infinite loops, in this case!) """ + # - When False, custom torch dispatch mode will error out explicitly when a hop + # is called under the mode. + # - When True, custom torch dispatch mode's __torch_dispatch__ will be triggered. + # Mode authors can implement how the mode interacts with higher order operators. + supports_higher_order_operators = False + def __init__(self, _dispatch_key=None): if _dispatch_key is not None: assert isinstance(_dispatch_key, torch._C.DispatchKey) diff --git a/torch/utils/benchmark/README.md b/torch/utils/benchmark/README.md index 4a64b778181f..6fa025e51d37 100644 --- a/torch/utils/benchmark/README.md +++ b/torch/utils/benchmark/README.md @@ -25,7 +25,7 @@ into two broad categories: * `Timer` implements the `blocked_autorange` function which is a mixture of `timeit.Timer.repeat` and `timeit.Timer.autorange`. This function - selects and appropriate number and runs for a roughly fixed amount of time + selects an appropriate number and runs for a roughly fixed amount of time (like `autorange`), but is less wasteful than `autorange` which discards ~75% of measurements. It runs many times, similar to `repeat`, and returns a `Measurement` containing all of the run results. @@ -46,7 +46,7 @@ table will be generated per unique label. may be logically equivalent differ in implementation. Assigning separate sub_labels will result in a row per sub_label. If a sublabel is not provided, `stmt` is used instead. Statistics (such as computing the fastest -implementation) are use all sub_labels. +implementation) use all sub_labels. * `description`: This describes the inputs. For instance, `stmt=torch.add(x, y)` can be run over several values of `x` and `y`. Each pair should be given its diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index c96ce82cf139..718e728c9389 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -4,6 +4,7 @@ from collections.abc import Iterator, Sized from typing import Any, Callable, Optional, TypeVar, Union +import torch from torch.utils.data._utils.collate import default_collate from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper @@ -74,6 +75,7 @@ def __init__( input_col=None, output_col=None, ) -> None: + torch._C._log_api_usage_once("python.data_pipes.map") super().__init__() self.datapipe = datapipe diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index 9a53ff9e84ac..e7aeae1ba3c8 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,9 +1,10 @@ -from typing import Any +from typing import Any, Optional import torch import enum from torch._C import _to_dlpack as to_dlpack +from torch.types import Device as _Device __all__ = [ "DLDeviceType", @@ -54,7 +55,12 @@ class DLDeviceType(enum.IntEnum): # TODO: add a typing.Protocol to be able to tell Mypy that only objects with # __dlpack__ and __dlpack_device__ methods are accepted. -def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': +def from_dlpack( + ext_tensor: Any, + *, + device: Optional[_Device] = None, + copy: Optional[bool] = None +) -> 'torch.Tensor': """from_dlpack(ext_tensor) -> Tensor Converts a tensor from an external library into a ``torch.Tensor``. @@ -76,6 +82,13 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': an opaque ``PyCapsule`` instance, typically produced by a ``to_dlpack`` function or method. + device (torch.device or str or None): An optional PyTorch device + specifying where to place the new tensor. If None (default), the + new tensor will be on the same device as ``ext_tensor``. + + copy (bool or None): An optional boolean indicating whether or not to copy + ``self``. If None, PyTorch will copy only if necessary. + Examples:: >>> import torch.utils.dlpack @@ -106,20 +119,36 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': """ if hasattr(ext_tensor, '__dlpack__'): + # Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise, + # leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call. kwargs: dict[str, Any] = {} kwargs["max_version"] = (1, 0) - device = ext_tensor.__dlpack_device__() - # device is either CUDA or ROCm, we need to pass the current + if copy is not None: + kwargs["copy"] = copy + + # Parse the device parameter. + # At this moment, it can either be a torch.device or a str representing + # a torch.device, e.g. "cpu", "cuda", etc. + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device), ( + f"from_dlpack: unsupported device type: {type(device)}" + ) + kwargs["dl_device"] = torch._C._torchDeviceToDLDevice(device) + + ext_device = ext_tensor.__dlpack_device__() + # ext_device is either CUDA or ROCm, we need to pass the current # stream - if device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): - stream = torch.cuda.current_stream(f'cuda:{device[1]}') + if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM): + stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}') # cuda_stream is the pointer to the stream and it is a public # attribute, but it is not documented # The array API specify that the default legacy stream must be passed # with a value of 1 for CUDA # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none - is_cuda = device[0] == DLDeviceType.kDLCUDA + is_cuda = ext_device[0] == DLDeviceType.kDLCUDA # Since pytorch is not using PTDS by default, lets directly pass # the legacy stream stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream @@ -134,6 +163,10 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor': dlpack = ext_tensor.__dlpack__(**kwargs) else: + assert device is None and copy is None, ( + "device and copy kwargs not supported when ext_tensor is " + "already a DLPack capsule." + ) # Old versions just call the converter dlpack = ext_tensor return torch._C._from_dlpack(dlpack) diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index b2053439140b..348e40eb6254 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -755,11 +755,81 @@ def _count_flops(self, func_packet, out, args, kwargs): return out - class _FlopCounterMode(TorchDispatchMode): + supports_higher_order_operators = True + def __init__(self, counter: FlopCounterMode): self.counter = counter + def _execute_with_isolated_flop_counting(self, branch_fn, operands): + """Execute a branch function and capture its FLOP counts without + affecting self.counter.flop_counts + + Args: + branch_fn: The branch function to execute + operands: Arguments to pass to the branch function + + Returns: + Tuple of (result, flop_counts) where result is the branch output + and flop_counts is a copy of the FLOP counts after execution + """ + import copy + checkpointed_flop_counts = copy.copy(self.counter.flop_counts) + with self: + result = branch_fn(*operands) + flop_counts = copy.copy(self.counter.flop_counts) + self.counter.flop_counts = checkpointed_flop_counts + return result, flop_counts + + def _handle_higher_order_ops(self, func, types, args, kwargs): + if func not in {torch.ops.higher_order.cond, }: + return NotImplemented + + # The flop counter for cond counts the upper bound of flops. + # For example, if a matmul is executed 2 times in true branch + # but only 1 time in the false branch, the flop counter will + # record the larger number of flops, i.e. 2 times. + if func is torch.ops.higher_order.cond: + + pred, true_branch, false_branch, operands = args + # Step 1: Count flops for true branch and false branch separately + true_out, true_flop_counts = self._execute_with_isolated_flop_counting( + true_branch, operands + ) + if true_out is NotImplemented: + return NotImplemented + + false_out, false_flop_counts = self._execute_with_isolated_flop_counting( + false_branch, operands + ) + if false_out is NotImplemented: + return NotImplemented + + # Step 2: merge flop counts + all_mod_keys = set(true_flop_counts.keys()) | set(false_flop_counts.keys()) + merged_flop_counts = {} + for outer_key in all_mod_keys: + true_func_counts = true_flop_counts[outer_key] + false_func_counts = false_flop_counts[outer_key] + + merged_func_counts = {} + all_func_keys = set(true_func_counts.keys()) | set(false_func_counts.keys()) + + for func_key in all_func_keys: + true_val = true_func_counts.get(func_key, 0) + false_val = false_func_counts.get(func_key, 0) + merged_func_counts[func_key] = max(true_val, false_val) + + merged_flop_counts[outer_key] = merged_func_counts + + # Step 3: update the counter with merged counts + for outer_key, inner_dict in merged_flop_counts.items(): + self.counter.flop_counts[outer_key].update(inner_dict) + + # It doesn't matter which one we return since true_fn and false_fn return + # output with the same structure. + return true_out + def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} @@ -781,6 +851,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return NotImplemented + if isinstance(func, torch._ops.HigherOrderOperator): + return self._handle_higher_order_ops(func, types, args, kwargs) + # If we don't have func in flop_registry, see if it can decompose if func not in self.counter.flop_registry and func is not torch.ops.prim.device.default: with self: diff --git a/torch/utils/weak.py b/torch/utils/weak.py index 8bf2ba5ed02b..9c7218cb2ad3 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -3,8 +3,6 @@ import collections.abc as _collections_abc import weakref - -from _weakrefset import _IterationGuard # type: ignore[attr-defined] from collections.abc import Mapping, MutableMapping from weakref import ref @@ -22,6 +20,33 @@ ] +# TODO: make weakref properly thread safe following +# https://github.com/python/cpython/pull/125325 +class _IterationGuard: + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). + + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + # This file defines a variant of WeakKeyDictionary that overrides the hashing # behavior of the key to use object identity, rather than the builtin # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their